From 8af7ad9bb948394445a8d4dbf3b561a5915b513d Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Sun, 26 Jun 2022 10:45:44 -0400 Subject: [PATCH] Formatter, take 2 --- .JuliaFormatter.toml | 1 + .github/workflows/FormatCheck.yml | 42 + docs/make.jl | 29 +- docs/pages.jl | 96 +- src/adjoint_common.jl | 870 +++---- src/backsolve_adjoint.jl | 717 +++--- src/callback_tracking.jl | 671 ++--- src/concrete_solve.jl | 2019 ++++++++------- src/derivative_wrappers.jl | 1478 +++++------ src/forward_sensitivity.jl | 695 +++--- src/hasbranching.jl | 66 +- src/interpolating_adjoint.jl | 985 ++++---- src/lss.jl | 1095 ++++---- src/nilsas.jl | 738 +++--- src/nilss.jl | 1008 ++++---- src/quadrature_adjoint.jl | 625 ++--- src/reversediff.jl | 89 +- src/sde_tools.jl | 131 +- src/second_order.jl | 35 +- src/sensitivity_algorithms.jl | 2201 +++++++++-------- src/sensitivity_interface.jl | 959 +++---- src/steadystate_adjoint.jl | 170 +- src/tracker.jl | 79 +- src/zygote.jl | 21 +- test/HybridNODE.jl | 167 +- test/adjoint.jl | 1251 +++++----- test/adjoint_param.jl | 96 +- test/adjoint_shapes.jl | 53 +- test/alternative_ad_frontend.jl | 212 +- test/array_partitions.jl | 81 +- test/autodiff_events.jl | 76 +- test/branching_derivatives.jl | 81 +- test/callback_reversediff.jl | 55 +- test/callbacks/SDE_callbacks.jl | 73 +- test/callbacks/continuous_callbacks.jl | 411 +-- test/callbacks/continuous_vs_discrete.jl | 301 ++- test/callbacks/discrete_callbacks.jl | 502 ++-- .../callbacks/forward_sensitivity_callback.jl | 75 +- test/callbacks/vector_continuous_callbacks.jl | 81 +- test/complex_adjoints.jl | 139 +- test/complex_no_u.jl | 15 +- test/concrete_solve_derivatives.jl | 727 ++++-- test/derivative_shapes.jl | 74 +- test/discrete.jl | 30 +- test/distributed.jl | 25 +- test/ensembles.jl | 53 +- test/forward.jl | 218 +- test/forward_chunking.jl | 222 +- test/forward_prob_kwargs.jl | 29 +- test/forward_remake.jl | 33 +- ...warddiffsensitivity_sparsity_components.jl | 41 +- test/gdp_regression_test.jl | 93 +- test/gpu/diffeqflux_standard_gpu.jl | 79 +- test/gpu/mixed_gpu_cpu_adjoint.jl | 42 +- test/hasbranching.jl | 4 +- test/hybrid_de.jl | 55 +- test/layers.jl | 42 +- test/layers_dde.jl | 30 +- test/layers_sde.jl | 72 +- test/literal_adjoint.jl | 26 +- test/mixed_costs.jl | 306 ++- test/null_parameters.jl | 31 +- test/parameter_compatibility_errors.jl | 37 +- test/partial_neural.jl | 76 +- test/prob_kwargs.jl | 60 +- test/rode.jl | 957 +++---- test/runtests.jl | 215 +- test/save_idxs.jl | 63 +- test/sde_checkpointing.jl | 79 +- test/sde_neural.jl | 139 +- test/sde_nondiag_stratonovich.jl | 829 ++++--- test/sde_scalar_ito.jl | 149 +- test/sde_scalar_stratonovich.jl | 274 +- test/sde_stratonovich.jl | 768 +++--- test/sde_transformation_test.jl | 385 ++- test/second_order.jl | 67 +- test/second_order_odes.jl | 64 +- test/shadowing.jl | 985 ++++---- test/size_handling_adjoint.jl | 24 +- test/sparse_adjoint.jl | 49 +- test/steady_state.jl | 612 +++-- test/stiff_adjoints.jl | 418 ++-- test/time_type_mixing.jl | 22 +- 83 files changed, 14333 insertions(+), 12660 deletions(-) create mode 100644 .JuliaFormatter.toml create mode 100644 .github/workflows/FormatCheck.yml diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml new file mode 100644 index 000000000..453925c3f --- /dev/null +++ b/.JuliaFormatter.toml @@ -0,0 +1 @@ +style = "sciml" \ No newline at end of file diff --git a/.github/workflows/FormatCheck.yml b/.github/workflows/FormatCheck.yml new file mode 100644 index 000000000..2a3517a0f --- /dev/null +++ b/.github/workflows/FormatCheck.yml @@ -0,0 +1,42 @@ +name: format-check + +on: + push: + branches: + - 'master' + - 'release-' + tags: '*' + pull_request: + +jobs: + build: + runs-on: ${{ matrix.os }} + strategy: + matrix: + julia-version: [1] + julia-arch: [x86] + os: [ubuntu-latest] + steps: + - uses: julia-actions/setup-julia@latest + with: + version: ${{ matrix.julia-version }} + + - uses: actions/checkout@v1 + - name: Install JuliaFormatter and format + # This will use the latest version by default but you can set the version like so: + # + # julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter", version="0.13.0"))' + run: | + julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' + julia -e 'using JuliaFormatter; format(".", verbose=true)' + - name: Format check + run: | + julia -e ' + out = Cmd(`git diff --name-only`) |> read |> String + if out == "" + exit(0) + else + @error "Some files have not been formatted !!!" + write(stdout, out) + exit(1) + end' diff --git a/docs/make.jl b/docs/make.jl index 36713de67..0c423040c 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -6,14 +6,12 @@ using Plots include("pages.jl") -makedocs( - sitename = "SciMLSensitivity.jl", - authors="Chris Rackauckas et al.", - clean = true, - doctest = false, - modules = [SciMLSensitivity], - - strict = [ +makedocs(sitename = "SciMLSensitivity.jl", + authors = "Chris Rackauckas et al.", + clean = true, + doctest = false, + modules = [SciMLSensitivity], + strict = [ :doctest, :linkcheck, :parse_error, @@ -21,14 +19,9 @@ makedocs( # Other available options are # :autodocs_block, :cross_references, :docs_block, :eval_block, :example_block, :footnote, :meta_block, :missing_docs, :setup_block ], + format = Documenter.HTML(assets = ["assets/favicon.ico"], + canonical = "https://sensitivity.sciml.ai/stable/"), + pages = pages) - format = Documenter.HTML(#analytics = "", - assets = ["assets/favicon.ico"], - canonical="https://sensitivity.sciml.ai/stable/"), - pages=pages -) - -deploydocs( - repo = "github.com/SciML/SciMLSensitivity.jl.git"; - push_preview = true -) +deploydocs(repo = "github.com/SciML/SciMLSensitivity.jl.git"; + push_preview = true) diff --git a/docs/pages.jl b/docs/pages.jl index 68ddaf74f..37d4091a0 100644 --- a/docs/pages.jl +++ b/docs/pages.jl @@ -1,62 +1,36 @@ pages = [ - "SciMLSensitivity.jl: Automatic Differentiation and Adjoints for (Differential) Equation Solvers" => "index.md", - "Tutorials" => Any[ - "Differentiating Ordinary Differential Equations (ODE) Tutorials" => Any[ - "ad_examples/differentiating_ode.md", - "ad_examples/direct_sensitivity.md", - "ad_examples/adjoint_continuous_functional.md", - "ad_examples/chaotic_ode.md", - ], - "Fitting Ordinary Differential Equation (ODE) Tutorials" => Any[ - "ode_fitting/optimization_ode.md", - "ode_fitting/stiff_ode_fit.md", - "ode_fitting/exogenous_input.md", - "ode_fitting/data_parallel.md", - "ode_fitting/prediction_error_method.md", - "ode_fitting/second_order_adjoints.md", - "ode_fitting/second_order_neural.md", - ], - "Training Techniques and Tips" => Any[ - "training_tips/local_minima.md", - "training_tips/divergence.md", - "training_tips/multiple_nn.md", - ], - "Neural Ordinary Differential Equation (Neural ODE) Tutorials" => Any[ - "neural_ode/neural_ode_flux.md", - "neural_ode/neural_gde.md", - "neural_ode/minibatch.md", - ], - "Stochastic Differential Equation (SDE) Tutorials" => Any[ - "sde_fitting/optimization_sde.md", - ], - "Delay Differential Equation (DDE) Tutorials" => Any[ - "dde_fitting/delay_diffeq.md", - ], - "Differential-Algebraic Equation (DAE) Tutorials" => Any[ - "dae_fitting/physical_constraints.md", - ], - "Partial Differential Equation (PDE) Tutorials" => Any[ - "pde_fitting/pde_constrained.md", - ], - "Hybrid and Jump Equation Tutorials" => Any[ - "hybrid_jump_fitting/hybrid_diffeq.md", - "hybrid_jump_fitting/bouncing_ball.md", - ], - "Bayesian Estimation Tutorials" => Any[ - "bayesian/turing_bayesian.md", - ], - "Optimal and Model Predictive Control Tutorials" => Any[ - "optimal_control/optimal_control.md", - "optimal_control/feedback_control.md", - "optimal_control/SDE_control.md", - ], - ], - "Manual and APIs" => Any[ - "manual/differential_equation_sensitivities.md", - "manual/nonlinear_solve_sensitivities.md", - "manual/direct_forward_sensitivity.md", - "manual/direct_adjoint_sensitivities.md", - ], - "Benchmarks" => "Benchmark.md", - "Sensitivity Math Details" => "sensitivity_math.md", - ] \ No newline at end of file + "SciMLSensitivity.jl: Automatic Differentiation and Adjoints for (Differential) Equation Solvers" => "index.md", + "Tutorials" => Any["Differentiating Ordinary Differential Equations (ODE) Tutorials" => Any["ad_examples/differentiating_ode.md", + "ad_examples/direct_sensitivity.md", + "ad_examples/adjoint_continuous_functional.md", + "ad_examples/chaotic_ode.md"], + "Fitting Ordinary Differential Equation (ODE) Tutorials" => Any["ode_fitting/optimization_ode.md", + "ode_fitting/stiff_ode_fit.md", + "ode_fitting/exogenous_input.md", + "ode_fitting/data_parallel.md", + "ode_fitting/prediction_error_method.md", + "ode_fitting/second_order_adjoints.md", + "ode_fitting/second_order_neural.md"], + "Training Techniques and Tips" => Any["training_tips/local_minima.md", + "training_tips/divergence.md", + "training_tips/multiple_nn.md"], + "Neural Ordinary Differential Equation (Neural ODE) Tutorials" => Any["neural_ode/neural_ode_flux.md", + "neural_ode/neural_gde.md", + "neural_ode/minibatch.md"], + "Stochastic Differential Equation (SDE) Tutorials" => Any["sde_fitting/optimization_sde.md"], + "Delay Differential Equation (DDE) Tutorials" => Any["dde_fitting/delay_diffeq.md"], + "Differential-Algebraic Equation (DAE) Tutorials" => Any["dae_fitting/physical_constraints.md"], + "Partial Differential Equation (PDE) Tutorials" => Any["pde_fitting/pde_constrained.md"], + "Hybrid and Jump Equation Tutorials" => Any["hybrid_jump_fitting/hybrid_diffeq.md", + "hybrid_jump_fitting/bouncing_ball.md"], + "Bayesian Estimation Tutorials" => Any["bayesian/turing_bayesian.md"], + "Optimal and Model Predictive Control Tutorials" => Any["optimal_control/optimal_control.md", + "optimal_control/feedback_control.md", + "optimal_control/SDE_control.md"]], + "Manual and APIs" => Any["manual/differential_equation_sensitivities.md", + "manual/nonlinear_solve_sensitivities.md", + "manual/direct_forward_sensitivity.md", + "manual/direct_adjoint_sensitivities.md"], + "Benchmarks" => "Benchmark.md", + "Sensitivity Math Details" => "sensitivity_math.md", +] diff --git a/src/adjoint_common.jl b/src/adjoint_common.jl index 7a36da56e..57225ba53 100644 --- a/src/adjoint_common.jl +++ b/src/adjoint_common.jl @@ -1,21 +1,22 @@ -struct AdjointDiffCache{UF,PF,G,TJ,PJT,uType,JC,GC,PJC,JNC,PJNC,rateType,DG,DI,AI,FM} - uf::UF - pf::PF - g::G - J::TJ - pJ::PJT - dg_val::uType - jac_config::JC - g_grad_config::GC - paramjac_config::PJC - jac_noise_config::JNC - paramjac_noise_config::PJNC - f_cache::rateType - dg::DG - diffvar_idxs::DI - algevar_idxs::AI - factorized_mass_matrix::FM - issemiexplicitdae::Bool +struct AdjointDiffCache{UF, PF, G, TJ, PJT, uType, JC, GC, PJC, JNC, PJNC, rateType, DG, DI, + AI, FM} + uf::UF + pf::PF + g::G + J::TJ + pJ::PJT + dg_val::uType + jac_config::JC + g_grad_config::GC + paramjac_config::PJC + jac_noise_config::JNC + paramjac_noise_config::PJNC + f_cache::rateType + dg::DG + diffvar_idxs::DI + algevar_idxs::AI + factorized_mass_matrix::FM + issemiexplicitdae::Bool end """ @@ -23,473 +24,494 @@ end return (AdjointDiffCache, y) """ -function adjointdiffcache(g::G,sensealg,discrete,sol,dg::DG,f;quad=false,noiseterm=false,needs_jac=false) where {G,DG} - prob = sol.prob - if prob isa DiffEqBase.SteadyStateProblem - @unpack u0, p = prob - tspan = (nothing, nothing) - #elseif prob isa SDEProblem - # @unpack tspan, u0, p = prob - else - @unpack u0, p, tspan = prob - end - numparams = p === nothing || p === DiffEqBase.NullParameters() ? 0 : length(p) - numindvar = length(u0) - isautojacvec = get_jacvec(sensealg) - - issemiexplicitdae = false - mass_matrix = sol.prob.f.mass_matrix - if mass_matrix isa UniformScaling - factorized_mass_matrix = mass_matrix' - elseif mass_matrix isa Tuple{UniformScaling,UniformScaling} - factorized_mass_matrix = (I',I') - else - mass_matrix = mass_matrix' - diffvar_idxs = findall(x->any(!iszero, @view(mass_matrix[:, x])), axes(mass_matrix, 2)) - algevar_idxs = setdiff(eachindex(u0), diffvar_idxs) - # TODO: operator - M̃ = @view mass_matrix[diffvar_idxs, diffvar_idxs] - factorized_mass_matrix = lu(M̃, check=false) - issuccess(factorized_mass_matrix) || error("The submatrix corresponding to the differential variables of the mass matrix must be nonsingular!") - isempty(algevar_idxs) || (issemiexplicitdae = true) - end - if !issemiexplicitdae - diffvar_idxs = eachindex(u0) - algevar_idxs = 1:0 - end - - if !needs_jac - J = (issemiexplicitdae || !isautojacvec) ? similar(u0, numindvar, numindvar) : nothing - else - # Force construction of the Jacobian - J = similar(u0, numindvar, numindvar) - end - - if !discrete - if dg !== nothing - pg = nothing - pg_config = nothing - if dg isa Tuple && length(dg) == 2 - dg_val = (similar(u0, numindvar),similar(u0, numparams)) - dg_val[1] .= false - dg_val[2] .= false - else - dg_val = similar(u0, numindvar) # number of funcs size - dg_val .= false - end +function adjointdiffcache(g::G, sensealg, discrete, sol, dg::DG, f; quad = false, + noiseterm = false, needs_jac = false) where {G, DG} + prob = sol.prob + if prob isa DiffEqBase.SteadyStateProblem + @unpack u0, p = prob + tspan = (nothing, nothing) + #elseif prob isa SDEProblem + # @unpack tspan, u0, p = prob else - if !(prob isa RODEProblem) - pg = UGradientWrapper(g,tspan[2],p) - else - pg = RODEUGradientWrapper(g,tspan[2],p,last(sol.W)) - end - pg_config = build_grad_config(sensealg,pg,u0,p) - dg_val = similar(u0, numindvar) # number of funcs size - dg_val .= false + @unpack u0, p, tspan = prob end - else - dg_val = nothing - pg = nothing - pg_config = nothing - end - - if DiffEqBase.has_jac(f) || (J === nothing) - jac_config = nothing - uf = nothing - else - if DiffEqBase.isinplace(prob) - if !(prob isa RODEProblem) - uf = DiffEqBase.UJacobianWrapper(f,tspan[2],p) - else - uf = RODEUJacobianWrapper(f,tspan[2],p,last(sol.W)) - end - jac_config = build_jac_config(sensealg,uf,u0) + numparams = p === nothing || p === DiffEqBase.NullParameters() ? 0 : length(p) + numindvar = length(u0) + isautojacvec = get_jacvec(sensealg) + + issemiexplicitdae = false + mass_matrix = sol.prob.f.mass_matrix + if mass_matrix isa UniformScaling + factorized_mass_matrix = mass_matrix' + elseif mass_matrix isa Tuple{UniformScaling, UniformScaling} + factorized_mass_matrix = (I', I') else - if !(prob isa RODEProblem) - uf = DiffEqBase.UDerivativeWrapper(f,tspan[2],p) - else - uf = RODEUDerivativeWrapper(f,tspan[2],p,last(sol.W)) - end - jac_config = nothing + mass_matrix = mass_matrix' + diffvar_idxs = findall(x -> any(!iszero, @view(mass_matrix[:, x])), + axes(mass_matrix, 2)) + algevar_idxs = setdiff(eachindex(u0), diffvar_idxs) + # TODO: operator + M̃ = @view mass_matrix[diffvar_idxs, diffvar_idxs] + factorized_mass_matrix = lu(M̃, check = false) + issuccess(factorized_mass_matrix) || + error("The submatrix corresponding to the differential variables of the mass matrix must be nonsingular!") + isempty(algevar_idxs) || (issemiexplicitdae = true) + end + if !issemiexplicitdae + diffvar_idxs = eachindex(u0) + algevar_idxs = 1:0 end - end - - if prob isa DiffEqBase.SteadyStateProblem - y = copy(sol.u) - else - y = copy(sol.u[end]) - end - - if typeof(prob.p) <: DiffEqBase.NullParameters - _p = similar(y,(0,)) - else - _p = prob.p - end - - @assert sensealg.autojacvec !== nothing - if sensealg.autojacvec isa ReverseDiffVJP - if prob isa DiffEqBase.SteadyStateProblem - if DiffEqBase.isinplace(prob) - tape = ReverseDiff.GradientTape((y, _p)) do u,p - du1 = p !== nothing && p !== DiffEqBase.NullParameters() ? similar(p, size(u)) : similar(u) - f(du1,u,p,nothing) - return vec(du1) - end - else - tape = ReverseDiff.GradientTape((y, _p)) do u,p - vec(f(u,p,nothing)) - end - end - elseif noiseterm && (!StochasticDiffEq.is_diagonal_noise(prob) || isnoisemixing(sensealg)) - tape = nothing + if !needs_jac + J = (issemiexplicitdae || !isautojacvec) ? similar(u0, numindvar, numindvar) : + nothing else - if DiffEqBase.isinplace(prob) - if !(prob isa RODEProblem) - tape = ReverseDiff.GradientTape((y, _p, [tspan[2]])) do u,p,t - du1 = (p !== nothing && p !== DiffEqBase.NullParameters()) ? similar(p, size(u)) : similar(u) - f(du1,u,p,first(t)) - return vec(du1) - end + # Force construction of the Jacobian + J = similar(u0, numindvar, numindvar) + end + + if !discrete + if dg !== nothing + pg = nothing + pg_config = nothing + if dg isa Tuple && length(dg) == 2 + dg_val = (similar(u0, numindvar), similar(u0, numparams)) + dg_val[1] .= false + dg_val[2] .= false + else + dg_val = similar(u0, numindvar) # number of funcs size + dg_val .= false + end else - tape = ReverseDiff.GradientTape((y, _p, [tspan[2]],last(sol.W))) do u,p,t,W - du1 = p !== nothing && p !== DiffEqBase.NullParameters() ? similar(p, size(u)) : similar(u) - f(du1,u,p,first(t),W) - return vec(du1) - end + if !(prob isa RODEProblem) + pg = UGradientWrapper(g, tspan[2], p) + else + pg = RODEUGradientWrapper(g, tspan[2], p, last(sol.W)) + end + pg_config = build_grad_config(sensealg, pg, u0, p) + dg_val = similar(u0, numindvar) # number of funcs size + dg_val .= false end - else - if !(prob isa RODEProblem) - tape = ReverseDiff.GradientTape((y, _p, [tspan[2]])) do u,p,t - vec(f(u,p,first(t))) - end + else + dg_val = nothing + pg = nothing + pg_config = nothing + end + + if DiffEqBase.has_jac(f) || (J === nothing) + jac_config = nothing + uf = nothing + else + if DiffEqBase.isinplace(prob) + if !(prob isa RODEProblem) + uf = DiffEqBase.UJacobianWrapper(f, tspan[2], p) + else + uf = RODEUJacobianWrapper(f, tspan[2], p, last(sol.W)) + end + jac_config = build_jac_config(sensealg, uf, u0) else - tape = ReverseDiff.GradientTape((y, _p, [tspan[2]],last(sol.W))) do u,p,t,W - return f(u,p,first(t),W) - end + if !(prob isa RODEProblem) + uf = DiffEqBase.UDerivativeWrapper(f, tspan[2], p) + else + uf = RODEUDerivativeWrapper(f, tspan[2], p, last(sol.W)) + end + jac_config = nothing end - end end - if compile_tape(sensealg.autojacvec) - paramjac_config = ReverseDiff.compile(tape) + if prob isa DiffEqBase.SteadyStateProblem + y = copy(sol.u) else - paramjac_config = tape + y = copy(sol.u[end]) end - pf = nothing - elseif sensealg.autojacvec isa EnzymeVJP if typeof(prob.p) <: DiffEqBase.NullParameters - paramjac_config = zero(y),prob.p,zero(y),zero(y) + _p = similar(y, (0,)) else - paramjac_config = zero(y),zero(_p),zero(y),zero(y) + _p = prob.p end - pf = let f = f.f - if DiffEqBase.isinplace(prob) && prob isa RODEProblem - function (out,u,_p,t,W) - f(out, u, _p, t, W) - nothing - end - elseif DiffEqBase.isinplace(prob) - function (out,u,_p,t) - f(out, u, _p, t) - nothing + + @assert sensealg.autojacvec !== nothing + + if sensealg.autojacvec isa ReverseDiffVJP + if prob isa DiffEqBase.SteadyStateProblem + if DiffEqBase.isinplace(prob) + tape = ReverseDiff.GradientTape((y, _p)) do u, p + du1 = p !== nothing && p !== DiffEqBase.NullParameters() ? + similar(p, size(u)) : similar(u) + f(du1, u, p, nothing) + return vec(du1) + end + else + tape = ReverseDiff.GradientTape((y, _p)) do u, p + vec(f(u, p, nothing)) + end end - elseif !DiffEqBase.isinplace(prob) && prob isa RODEProblem - function (out,u,_p,t,W) - out .= f(u, _p, t, W) - nothing + elseif noiseterm && + (!StochasticDiffEq.is_diagonal_noise(prob) || isnoisemixing(sensealg)) + tape = nothing + else + if DiffEqBase.isinplace(prob) + if !(prob isa RODEProblem) + tape = ReverseDiff.GradientTape((y, _p, [tspan[2]])) do u, p, t + du1 = (p !== nothing && p !== DiffEqBase.NullParameters()) ? + similar(p, size(u)) : similar(u) + f(du1, u, p, first(t)) + return vec(du1) + end + else + tape = ReverseDiff.GradientTape((y, _p, [tspan[2]], last(sol.W))) do u, + p, + t, + W + du1 = p !== nothing && p !== DiffEqBase.NullParameters() ? + similar(p, size(u)) : similar(u) + f(du1, u, p, first(t), W) + return vec(du1) + end + end + else + if !(prob isa RODEProblem) + tape = ReverseDiff.GradientTape((y, _p, [tspan[2]])) do u, p, t + vec(f(u, p, first(t))) + end + else + tape = ReverseDiff.GradientTape((y, _p, [tspan[2]], last(sol.W))) do u, + p, + t, + W + return f(u, p, first(t), W) + end + end end - else !DiffEqBase.isinplace(prob) - function (out,u,_p,t) - out .= f(u, _p, t) - nothing + end + + if compile_tape(sensealg.autojacvec) + paramjac_config = ReverseDiff.compile(tape) + else + paramjac_config = tape + end + + pf = nothing + elseif sensealg.autojacvec isa EnzymeVJP + if typeof(prob.p) <: DiffEqBase.NullParameters + paramjac_config = zero(y), prob.p, zero(y), zero(y) + else + paramjac_config = zero(y), zero(_p), zero(y), zero(y) + end + pf = let f = f.f + if DiffEqBase.isinplace(prob) && prob isa RODEProblem + function (out, u, _p, t, W) + f(out, u, _p, t, W) + nothing + end + elseif DiffEqBase.isinplace(prob) + function (out, u, _p, t) + f(out, u, _p, t) + nothing + end + elseif !DiffEqBase.isinplace(prob) && prob isa RODEProblem + function (out, u, _p, t, W) + out .= f(u, _p, t, W) + nothing + end + else + !DiffEqBase.isinplace(prob) + function (out, u, _p, t) + out .= f(u, _p, t) + nothing + end end end - end - elseif DiffEqBase.has_paramjac(f) || isautojacvec || quad || sensealg.autojacvec isa EnzymeVJP - paramjac_config = nothing - pf = nothing - else - if DiffEqBase.isinplace(prob) - if !(prob isa RODEProblem) - pf = DiffEqBase.ParamJacobianWrapper(f,tspan[1],y) - else - pf = RODEParamJacobianWrapper(f,tspan[1],y,last(sol.W)) - end - paramjac_config = build_param_jac_config(sensealg,pf,y,p) + elseif DiffEqBase.has_paramjac(f) || isautojacvec || quad || + sensealg.autojacvec isa EnzymeVJP + paramjac_config = nothing + pf = nothing else - if !(prob isa RODEProblem) - pf = ParamGradientWrapper(f,tspan[2],y) - else - pf = RODEParamGradientWrapper(f,tspan[2],y,last(sol.W)) - end - paramjac_config = nothing + if DiffEqBase.isinplace(prob) + if !(prob isa RODEProblem) + pf = DiffEqBase.ParamJacobianWrapper(f, tspan[1], y) + else + pf = RODEParamJacobianWrapper(f, tspan[1], y, last(sol.W)) + end + paramjac_config = build_param_jac_config(sensealg, pf, y, p) + else + if !(prob isa RODEProblem) + pf = ParamGradientWrapper(f, tspan[2], y) + else + pf = RODEParamGradientWrapper(f, tspan[2], y, last(sol.W)) + end + paramjac_config = nothing + end end - end - pJ = (quad || isautojacvec) ? nothing : similar(u0, numindvar, numparams) - - f_cache = DiffEqBase.isinplace(prob) ? deepcopy(u0) : nothing - - if noiseterm - if sensealg.autojacvec isa ReverseDiffVJP + pJ = (quad || isautojacvec) ? nothing : similar(u0, numindvar, numparams) - jac_noise_config = nothing - paramjac_noise_config = [] + f_cache = DiffEqBase.isinplace(prob) ? deepcopy(u0) : nothing - if DiffEqBase.isinplace(prob) - for i in 1:numindvar - function noisetape(indx) - if StochasticDiffEq.is_diagonal_noise(prob) - ReverseDiff.GradientTape((y, _p, [tspan[2]])) do u,p,t - du1 = p !== nothing && p !== DiffEqBase.NullParameters() ? similar(p, size(u)) : similar(u) - f(du1,u,p,first(t)) - return du1[indx] - end + if noiseterm + if sensealg.autojacvec isa ReverseDiffVJP + jac_noise_config = nothing + paramjac_noise_config = [] + + if DiffEqBase.isinplace(prob) + for i in 1:numindvar + function noisetape(indx) + if StochasticDiffEq.is_diagonal_noise(prob) + ReverseDiff.GradientTape((y, _p, [tspan[2]])) do u, p, t + du1 = p !== nothing && p !== DiffEqBase.NullParameters() ? + similar(p, size(u)) : similar(u) + f(du1, u, p, first(t)) + return du1[indx] + end + else + ReverseDiff.GradientTape((y, _p, [tspan[2]])) do u, p, t + du1 = similar(p, size(prob.noise_rate_prototype)) + du1 .= false + f(du1, u, p, first(t)) + return du1[:, indx] + end + end + end + tapei = noisetape(i) + if compile_tape(sensealg.autojacvec) + push!(paramjac_noise_config, ReverseDiff.compile(tapei)) + else + push!(paramjac_noise_config, tapei) + end + end else - ReverseDiff.GradientTape((y, _p, [tspan[2]])) do u,p,t - du1 = similar(p, size(prob.noise_rate_prototype)) - du1 .= false - f(du1,u,p,first(t)) - return du1[:,indx] - end + for i in 1:numindvar + function noisetapeoop(indx) + if StochasticDiffEq.is_diagonal_noise(prob) + ReverseDiff.GradientTape((y, _p, [tspan[2]])) do u, p, t + f(u, p, first(t))[indx] + end + else + ReverseDiff.GradientTape((y, _p, [tspan[2]])) do u, p, t + f(u, p, first(t))[:, indx] + end + end + end + tapei = noisetapeoop(i) + if compile_tape(sensealg.autojacvec) + push!(paramjac_noise_config, ReverseDiff.compile(tapei)) + else + push!(paramjac_noise_config, tapei) + end + end + end + elseif sensealg.autojacvec isa Bool + if DiffEqBase.isinplace(prob) + if StochasticDiffEq.is_diagonal_noise(prob) + pf = DiffEqBase.ParamJacobianWrapper(f, tspan[1], y) + if isnoisemixing(sensealg) + uf = DiffEqBase.UJacobianWrapper(f, tspan[2], p) + jac_noise_config = build_jac_config(sensealg, uf, u0) + else + jac_noise_config = nothing + end + else + pf = ParamNonDiagNoiseJacobianWrapper(f, tspan[1], y, + prob.noise_rate_prototype) + uf = UNonDiagNoiseJacobianWrapper(f, tspan[2], p, + prob.noise_rate_prototype) + jac_noise_config = build_jac_config(sensealg, uf, u0) + end + paramjac_noise_config = build_param_jac_config(sensealg, pf, y, p) + else + if StochasticDiffEq.is_diagonal_noise(prob) + pf = ParamGradientWrapper(f, tspan[2], y) + if isnoisemixing(sensealg) + uf = DiffEqBase.UDerivativeWrapper(f, tspan[2], p) + end + else + pf = ParamNonDiagNoiseGradientWrapper(f, tspan[1], y) + uf = UNonDiagNoiseGradientWrapper(f, tspan[2], p) + end + paramjac_noise_config = nothing + jac_noise_config = nothing end - end - tapei = noisetape(i) - if compile_tape(sensealg.autojacvec) - push!(paramjac_noise_config, ReverseDiff.compile(tapei)) - else - push!(paramjac_noise_config, tapei) - end - end - else - for i in 1:numindvar - function noisetapeoop(indx) if StochasticDiffEq.is_diagonal_noise(prob) - ReverseDiff.GradientTape((y, _p, [tspan[2]])) do u,p,t - f(u,p,first(t))[indx] - end + pJ = similar(u0, numindvar, numparams) + if isnoisemixing(sensealg) + J = similar(u0, numindvar, numindvar) + end else - ReverseDiff.GradientTape((y, _p, [tspan[2]])) do u,p,t - f(u,p,first(t))[:,indx] - end + pJ = similar(u0, numindvar * numindvar, numparams) + J = similar(u0, numindvar * numindvar, numindvar) end - end - tapei = noisetapeoop(i) - if compile_tape(sensealg.autojacvec) - push!(paramjac_noise_config, ReverseDiff.compile(tapei)) - else - push!(paramjac_noise_config, tapei) - end - end - end - elseif sensealg.autojacvec isa Bool - if DiffEqBase.isinplace(prob) - if StochasticDiffEq.is_diagonal_noise(prob) - pf = DiffEqBase.ParamJacobianWrapper(f,tspan[1],y) - if isnoisemixing(sensealg) - uf = DiffEqBase.UJacobianWrapper(f,tspan[2],p) - jac_noise_config = build_jac_config(sensealg,uf,u0) - else - jac_noise_config = nothing - end - else - pf = ParamNonDiagNoiseJacobianWrapper(f,tspan[1],y,prob.noise_rate_prototype) - uf = UNonDiagNoiseJacobianWrapper(f,tspan[2],p,prob.noise_rate_prototype) - jac_noise_config = build_jac_config(sensealg,uf,u0) - end - paramjac_noise_config = build_param_jac_config(sensealg,pf,y,p) - else - if StochasticDiffEq.is_diagonal_noise(prob) - pf = ParamGradientWrapper(f,tspan[2],y) - if isnoisemixing(sensealg) - uf = DiffEqBase.UDerivativeWrapper(f,tspan[2],p) - end + else - pf = ParamNonDiagNoiseGradientWrapper(f,tspan[1],y) - uf = UNonDiagNoiseGradientWrapper(f,tspan[2],p) + paramjac_noise_config = nothing + jac_noise_config = nothing end + else paramjac_noise_config = nothing jac_noise_config = nothing - end - if StochasticDiffEq.is_diagonal_noise(prob) - pJ = similar(u0, numindvar, numparams) - if isnoisemixing(sensealg) - J = similar(u0, numindvar, numindvar) - end - else - pJ = similar(u0, numindvar*numindvar, numparams) - J = similar(u0, numindvar*numindvar, numindvar) - end - - else - paramjac_noise_config = nothing - jac_noise_config = nothing end - else - paramjac_noise_config = nothing - jac_noise_config = nothing - end - - adjoint_cache = AdjointDiffCache(uf,pf,pg,J,pJ,dg_val, - jac_config,pg_config,paramjac_config, - jac_noise_config,paramjac_noise_config, - f_cache,dg,diffvar_idxs,algevar_idxs, - factorized_mass_matrix,issemiexplicitdae) - - return adjoint_cache, y + + adjoint_cache = AdjointDiffCache(uf, pf, pg, J, pJ, dg_val, + jac_config, pg_config, paramjac_config, + jac_noise_config, paramjac_noise_config, + f_cache, dg, diffvar_idxs, algevar_idxs, + factorized_mass_matrix, issemiexplicitdae) + + return adjoint_cache, y end -getprob(S::SensitivityFunction) = (S isa ODEBacksolveSensitivityFunction) ? S.prob : S.sol.prob +function getprob(S::SensitivityFunction) + (S isa ODEBacksolveSensitivityFunction) ? S.prob : S.sol.prob +end inplace_sensitivity(S::SensitivityFunction) = isinplace(getprob(S)) -struct ReverseLossCallback{λType,timeType,yType,RefType,FMType,AlgType,gType,cacheType} - isq::Bool - λ::λType - t::timeType - y::yType - cur_time::RefType - idx::Int - F::FMType - sensealg::AlgType - g::gType - diffcache::cacheType +struct ReverseLossCallback{λType, timeType, yType, RefType, FMType, AlgType, gType, + cacheType} + isq::Bool + λ::λType + t::timeType + y::yType + cur_time::RefType + idx::Int + F::FMType + sensealg::AlgType + g::gType + diffcache::cacheType end function ReverseLossCallback(sensefun, λ, t, g, cur_time) - @unpack sensealg, y = sensefun - isq = (sensealg isa QuadratureAdjoint) + @unpack sensealg, y = sensefun + isq = (sensealg isa QuadratureAdjoint) - @unpack factorized_mass_matrix = sensefun.diffcache - prob = getprob(sensefun) - idx = length(prob.u0) + @unpack factorized_mass_matrix = sensefun.diffcache + prob = getprob(sensefun) + idx = length(prob.u0) - return ReverseLossCallback(isq, λ, t, y, cur_time, idx, factorized_mass_matrix, sensealg, g, sensefun.diffcache) + return ReverseLossCallback(isq, λ, t, y, cur_time, idx, factorized_mass_matrix, + sensealg, g, sensefun.diffcache) end function (f::ReverseLossCallback)(integrator) - @unpack isq, λ, t, y, cur_time, idx, F, sensealg, g = f - @unpack diffvar_idxs, algevar_idxs, issemiexplicitdae, J, uf, f_cache, jac_config = f.diffcache - - p, u = integrator.p, integrator.u - - if sensealg isa BacksolveAdjoint - copyto!(y,integrator.u[end-idx+1:end]) - end - - # Warning: alias here! Be careful with λ - gᵤ = isq ? λ : @view(λ[1:idx]) - g(gᵤ,y,p,t[cur_time[]],cur_time[]) - - if issemiexplicitdae - jacobian!(J, uf, y, f_cache, sensealg, jac_config) - dhdd = J[algevar_idxs, diffvar_idxs] - dhda = J[algevar_idxs, algevar_idxs] - # TODO: maybe need a `conj` - Δλa = -dhda'\gᵤ[algevar_idxs] - Δλd = dhdd'Δλa + gᵤ[diffvar_idxs] - else - Δλd = gᵤ - end - - if F !== nothing - F !== I && F !== (I,I) && ldiv!(F, Δλd) - end - - u[diffvar_idxs] .+= Δλd - u_modified!(integrator,true) - cur_time[] -= 1 - return nothing + @unpack isq, λ, t, y, cur_time, idx, F, sensealg, g = f + @unpack diffvar_idxs, algevar_idxs, issemiexplicitdae, J, uf, f_cache, jac_config = f.diffcache + + p, u = integrator.p, integrator.u + + if sensealg isa BacksolveAdjoint + copyto!(y, integrator.u[(end - idx + 1):end]) + end + + # Warning: alias here! Be careful with λ + gᵤ = isq ? λ : @view(λ[1:idx]) + g(gᵤ, y, p, t[cur_time[]], cur_time[]) + + if issemiexplicitdae + jacobian!(J, uf, y, f_cache, sensealg, jac_config) + dhdd = J[algevar_idxs, diffvar_idxs] + dhda = J[algevar_idxs, algevar_idxs] + # TODO: maybe need a `conj` + Δλa = -dhda' \ gᵤ[algevar_idxs] + Δλd = dhdd'Δλa + gᵤ[diffvar_idxs] + else + Δλd = gᵤ + end + + if F !== nothing + F !== I && F !== (I, I) && ldiv!(F, Δλd) + end + + u[diffvar_idxs] .+= Δλd + u_modified!(integrator, true) + cur_time[] -= 1 + return nothing end # handle discrete loss contributions -function generate_callbacks(sensefun, dg, λ, t, t0, callback, init_cb, terminated=false) - if sensefun isa NILSASSensitivityFunction - @unpack sensealg = sensefun.S - else - @unpack sensealg = sensefun - end - - if !init_cb - cur_time = Ref(1) - else - cur_time = Ref(length(t)) - end - - reverse_cbs = setup_reverse_callbacks(callback,sensealg,dg,cur_time,terminated) - init_cb || return reverse_cbs, nothing - - # callbacks can lead to non-unique time points - _t, duplicate_iterator_times = separate_nonunique(t) - - rlcb = ReverseLossCallback(sensefun, λ, t, dg, cur_time) - - if eltype(_t) !== typeof(t0) - _t = convert.(typeof(t0),_t) - end - cb = PresetTimeCallback(_t,rlcb) - - # handle duplicates (currently only for double occurances) - if duplicate_iterator_times!==nothing - # use same ref for cur_time to cope with concrete_solve - cbrev_dupl_affect = ReverseLossCallback(sensefun, λ, t, dg, cur_time) - cb_dupl = PresetTimeCallback(duplicate_iterator_times[1],cbrev_dupl_affect) - return CallbackSet(cb,reverse_cbs,cb_dupl), duplicate_iterator_times - else - return CallbackSet(cb,reverse_cbs), duplicate_iterator_times - end -end +function generate_callbacks(sensefun, dg, λ, t, t0, callback, init_cb, terminated = false) + if sensefun isa NILSASSensitivityFunction + @unpack sensealg = sensefun.S + else + @unpack sensealg = sensefun + end + if !init_cb + cur_time = Ref(1) + else + cur_time = Ref(length(t)) + end -function separate_nonunique(t) - # t is already sorted - _t = unique(t) - ts_with_occurances = [(i, count(==(i), t)) for i in _t] + reverse_cbs = setup_reverse_callbacks(callback, sensealg, dg, cur_time, terminated) + init_cb || return reverse_cbs, nothing - # duplicates (only those values which occur > 1 times) - dupl = filter(x->last(x)>1, ts_with_occurances) + # callbacks can lead to non-unique time points + _t, duplicate_iterator_times = separate_nonunique(t) - ts = first.(dupl) - occurances = last.(dupl) + rlcb = ReverseLossCallback(sensefun, λ, t, dg, cur_time) + if eltype(_t) !== typeof(t0) + _t = convert.(typeof(t0), _t) + end + cb = PresetTimeCallback(_t, rlcb) + + # handle duplicates (currently only for double occurances) + if duplicate_iterator_times !== nothing + # use same ref for cur_time to cope with concrete_solve + cbrev_dupl_affect = ReverseLossCallback(sensefun, λ, t, dg, cur_time) + cb_dupl = PresetTimeCallback(duplicate_iterator_times[1], cbrev_dupl_affect) + return CallbackSet(cb, reverse_cbs, cb_dupl), duplicate_iterator_times + else + return CallbackSet(cb, reverse_cbs), duplicate_iterator_times + end +end + +function separate_nonunique(t) + # t is already sorted + _t = unique(t) + ts_with_occurances = [(i, count(==(i), t)) for i in _t] - if isempty(occurances) - itrs = nothing - else - maxoc = maximum(occurances) - maxoc > 2 && error("More than two occurances of the same time point. Please report this.") - # handle also more than two occurances - itrs = [ts[occurances .>= i] for i=2:maxoc] - end + # duplicates (only those values which occur > 1 times) + dupl = filter(x -> last(x) > 1, ts_with_occurances) - return _t, itrs + ts = first.(dupl) + occurances = last.(dupl) + + if isempty(occurances) + itrs = nothing + else + maxoc = maximum(occurances) + maxoc > 2 && + error("More than two occurances of the same time point. Please report this.") + # handle also more than two occurances + itrs = [ts[occurances .>= i] for i in 2:maxoc] + end + + return _t, itrs end function out_and_ts(_ts, duplicate_iterator_times, sol) - if duplicate_iterator_times === nothing - ts = _ts - out = sol(ts) - else - # if callbacks are tracked, there is potentially an event_time that must be considered - # in the loss function but doesn't occur in saveat/t. So we need to add it. - # Note that if it doens't occur in saveat/t we even need to add it twice - # However if the callbacks are not saving in the forward, we don't want to compute a loss - # value for them. This information is given by sol.t/checkpoints. - # Additionally we need to store the left and the right limit, respectively. - duplicate_times = duplicate_iterator_times[1] # just treat two occurances at the moment (see separate_nonunique above) - _ts = Array(_ts) - for d in duplicate_times - (d ∉ _ts) && push!(_ts, d) - end + if duplicate_iterator_times === nothing + ts = _ts + out = sol(ts) + else + # if callbacks are tracked, there is potentially an event_time that must be considered + # in the loss function but doesn't occur in saveat/t. So we need to add it. + # Note that if it doens't occur in saveat/t we even need to add it twice + # However if the callbacks are not saving in the forward, we don't want to compute a loss + # value for them. This information is given by sol.t/checkpoints. + # Additionally we need to store the left and the right limit, respectively. + duplicate_times = duplicate_iterator_times[1] # just treat two occurances at the moment (see separate_nonunique above) + _ts = Array(_ts) + for d in duplicate_times + (d ∉ _ts) && push!(_ts, d) + end - u1 = sol(_ts).u - u2 = sol(duplicate_times,continuity=:right).u - saveat = vcat(_ts, duplicate_times...) - perm = sortperm(saveat) - ts = saveat[perm] - u = vcat(u1, u2)[perm] - out = DiffEqArray(u,ts) - end - return out, ts + u1 = sol(_ts).u + u2 = sol(duplicate_times, continuity = :right).u + saveat = vcat(_ts, duplicate_times...) + perm = sortperm(saveat) + ts = saveat[perm] + u = vcat(u1, u2)[perm] + out = DiffEqArray(u, ts) + end + return out, ts end diff --git a/src/backsolve_adjoint.jl b/src/backsolve_adjoint.jl index 5828d9757..371065c66 100644 --- a/src/backsolve_adjoint.jl +++ b/src/backsolve_adjoint.jl @@ -1,395 +1,424 @@ -struct ODEBacksolveSensitivityFunction{C<:AdjointDiffCache,Alg<:BacksolveAdjoint,uType,pType,fType<:DiffEqBase.AbstractDiffEqFunction} <: SensitivityFunction - diffcache::C - sensealg::Alg - discrete::Bool - y::uType - prob::pType - f::fType - noiseterm::Bool +struct ODEBacksolveSensitivityFunction{C <: AdjointDiffCache, Alg <: BacksolveAdjoint, + uType, pType, + fType <: DiffEqBase.AbstractDiffEqFunction} <: + SensitivityFunction + diffcache::C + sensealg::Alg + discrete::Bool + y::uType + prob::pType + f::fType + noiseterm::Bool end +function ODEBacksolveSensitivityFunction(g, sensealg, discrete, sol, dg, f; + noiseterm = false) + diffcache, y = adjointdiffcache(g, sensealg, discrete, sol, dg, f; quad = false, + noiseterm = noiseterm) -function ODEBacksolveSensitivityFunction(g,sensealg,discrete,sol,dg,f;noiseterm=false) - diffcache, y = adjointdiffcache(g,sensealg,discrete,sol,dg,f;quad=false,noiseterm=noiseterm) - - return ODEBacksolveSensitivityFunction(diffcache,sensealg,discrete, - y,sol.prob,f,noiseterm) + return ODEBacksolveSensitivityFunction(diffcache, sensealg, discrete, + y, sol.prob, f, noiseterm) end -function (S::ODEBacksolveSensitivityFunction)(du,u,p,t) - @unpack y, prob, discrete = S +function (S::ODEBacksolveSensitivityFunction)(du, u, p, t) + @unpack y, prob, discrete = S - λ,grad,_y,dλ,dgrad,dy = split_states(du,u,t,S) + λ, grad, _y, dλ, dgrad, dy = split_states(du, u, t, S) - if eltype(_y) <: ForwardDiff.Dual # handle implicit solvers - copyto!(vec(y), ForwardDiff.value.(_y)) - else - copyto!(vec(y), _y) - end + if eltype(_y) <: ForwardDiff.Dual # handle implicit solvers + copyto!(vec(y), ForwardDiff.value.(_y)) + else + copyto!(vec(y), _y) + end - if S.noiseterm - if length(u) == length(du) - vecjacobian!(dλ, y, λ, p, t, S, dgrad=dgrad, dy=dy) - elseif length(u) != length(du) && StochasticDiffEq.is_diagonal_noise(prob) && !isnoisemixing(S.sensealg) - vecjacobian!(dλ, y, λ, p, t, S, dy=dy) - jacNoise!(λ, y, p, t, S, dgrad=dgrad) + if S.noiseterm + if length(u) == length(du) + vecjacobian!(dλ, y, λ, p, t, S, dgrad = dgrad, dy = dy) + elseif length(u) != length(du) && StochasticDiffEq.is_diagonal_noise(prob) && + !isnoisemixing(S.sensealg) + vecjacobian!(dλ, y, λ, p, t, S, dy = dy) + jacNoise!(λ, y, p, t, S, dgrad = dgrad) + else + jacNoise!(λ, y, p, t, S, dgrad = dgrad, dλ = dλ, dy = dy) + end else - jacNoise!(λ, y, p, t, S, dgrad=dgrad, dλ=dλ, dy=dy) + vecjacobian!(dλ, y, λ, p, t, S, dgrad = dgrad, dy = dy) end - else - vecjacobian!(dλ, y, λ, p, t, S, dgrad=dgrad, dy=dy) - end - dλ .*= -1 - dgrad .*= -one(eltype(dgrad)) - - discrete || accumulate_cost!(dλ, y, p, t, S, dgrad) - return nothing + dλ .*= -1 + dgrad .*= -one(eltype(dgrad)) + + discrete || accumulate_cost!(dλ, y, p, t, S, dgrad) + return nothing end # u = λ' # for the RODE case -function (S::ODEBacksolveSensitivityFunction)(du,u,p,t,W) - @unpack y, prob, discrete = S +function (S::ODEBacksolveSensitivityFunction)(du, u, p, t, W) + @unpack y, prob, discrete = S - λ,grad,_y,dλ,dgrad,dy = split_states(du,u,t,S) - copyto!(vec(y), _y) + λ, grad, _y, dλ, dgrad, dy = split_states(du, u, t, S) + copyto!(vec(y), _y) - vecjacobian!(dλ, y, λ, p, t, S, dgrad=dgrad, dy=dy,W=W) - dλ .*= -one(eltype(λ)) - dgrad .*= -one(eltype(dgrad)) + vecjacobian!(dλ, y, λ, p, t, S, dgrad = dgrad, dy = dy, W = W) + dλ .*= -one(eltype(λ)) + dgrad .*= -one(eltype(dgrad)) - discrete || accumulate_cost!(dλ, y, p, t, S, dgrad) - return nothing + discrete || accumulate_cost!(dλ, y, p, t, S, dgrad) + return nothing end -function split_states(du,u,t,S::ODEBacksolveSensitivityFunction;update=true) - @unpack y, prob = S - idx = length(y) - - λ = @view u[1:idx] - grad = @view u[idx+1:end-idx] - _y = @view u[end-idx+1:end] - - if length(u) == length(du) - # ODE/Drift term and scalar noise - dλ = @view du[1:idx] - dgrad = @view du[idx+1:end-idx] - dy = @view du[end-idx+1:end] - - elseif length(u) != length(du) && StochasticDiffEq.is_diagonal_noise(prob) && !isnoisemixing(S.sensealg) - # Diffusion term, diagonal noise, length(du) = u*m - idx1 = [length(u)*(i-1)+i for i in 1:idx] # for diagonal indices of [1:idx,1:idx] - idx2 = [(length(u)+1)*i-idx for i in 1:idx] # for diagonal indices of [end-idx+1:end,1:idx] - - dλ = @view du[idx1] - dgrad = @view du[idx+1:end-idx,1:idx] - dy = @view du[idx2] - - elseif length(u) != length(du) && StochasticDiffEq.is_diagonal_noise(prob) && isnoisemixing(S.sensealg) - # Diffusion term, diagonal noise, (as above but can handle mixing noise terms) - idx2 = [(length(u)+1)*i-idx for i in 1:idx] # for diagonal indices of [end-idx+1:end,1:idx] - - dλ = @view du[1:idx,1:idx] - dgrad = @view du[idx+1:end-idx,1:idx] - dy = @view du[idx2] - - elseif typeof(du) <: AbstractMatrix - # non-diagonal noise - dλ = @view du[1:idx, 1:idx] - dgrad = @view du[idx+1:end-idx,1:idx] - dy = @view du[end-idx+1:end, 1:idx] - end - λ,grad,_y,dλ,dgrad,dy +function split_states(du, u, t, S::ODEBacksolveSensitivityFunction; update = true) + @unpack y, prob = S + idx = length(y) + + λ = @view u[1:idx] + grad = @view u[(idx + 1):(end - idx)] + _y = @view u[(end - idx + 1):end] + + if length(u) == length(du) + # ODE/Drift term and scalar noise + dλ = @view du[1:idx] + dgrad = @view du[(idx + 1):(end - idx)] + dy = @view du[(end - idx + 1):end] + + elseif length(u) != length(du) && StochasticDiffEq.is_diagonal_noise(prob) && + !isnoisemixing(S.sensealg) + # Diffusion term, diagonal noise, length(du) = u*m + idx1 = [length(u) * (i - 1) + i for i in 1:idx] # for diagonal indices of [1:idx,1:idx] + idx2 = [(length(u) + 1) * i - idx for i in 1:idx] # for diagonal indices of [end-idx+1:end,1:idx] + + dλ = @view du[idx1] + dgrad = @view du[(idx + 1):(end - idx), 1:idx] + dy = @view du[idx2] + + elseif length(u) != length(du) && StochasticDiffEq.is_diagonal_noise(prob) && + isnoisemixing(S.sensealg) + # Diffusion term, diagonal noise, (as above but can handle mixing noise terms) + idx2 = [(length(u) + 1) * i - idx for i in 1:idx] # for diagonal indices of [end-idx+1:end,1:idx] + + dλ = @view du[1:idx, 1:idx] + dgrad = @view du[(idx + 1):(end - idx), 1:idx] + dy = @view du[idx2] + + elseif typeof(du) <: AbstractMatrix + # non-diagonal noise + dλ = @view du[1:idx, 1:idx] + dgrad = @view du[(idx + 1):(end - idx), 1:idx] + dy = @view du[(end - idx + 1):end, 1:idx] + end + λ, grad, _y, dλ, dgrad, dy end # g is either g(t,u,p) or discrete g(t,u,i) -@noinline function ODEAdjointProblem(sol,sensealg::BacksolveAdjoint, - t=nothing, - dg_discrete::DG1=nothing,dg_continuous::DG2=nothing, - g::G=nothing; - checkpoints=sol.t, - callback=CallbackSet(), - z0=nothing, - M=nothing, - nilss=nothing, - tspan=sol.prob.tspan, - kwargs...) where {DG1,DG2,G} - # add homogenous adjoint for NILSAS by explicitly passing a z0 and nilss::NILSSSensitivityFunction - dg_discrete===nothing && dg_continuous===nothing && g===nothing && error("Either `dg_discrete`, `dg_continuous`, or `g` must be specified.") - - @unpack f, p, u0 = sol.prob - - # check if solution was terminated, then use reduced time span - terminated = false - if hasfield(typeof(sol),:retcode) - if sol.retcode == :Terminated - tspan = (tspan[1], sol.t[end]) - terminated = true +@noinline function ODEAdjointProblem(sol, sensealg::BacksolveAdjoint, + t = nothing, + dg_discrete::DG1 = nothing, + dg_continuous::DG2 = nothing, + g::G = nothing; + checkpoints = sol.t, + callback = CallbackSet(), + z0 = nothing, + M = nothing, + nilss = nothing, + tspan = sol.prob.tspan, + kwargs...) where {DG1, DG2, G} + # add homogenous adjoint for NILSAS by explicitly passing a z0 and nilss::NILSSSensitivityFunction + dg_discrete === nothing && dg_continuous === nothing && g === nothing && + error("Either `dg_discrete`, `dg_continuous`, or `g` must be specified.") + + @unpack f, p, u0 = sol.prob + + # check if solution was terminated, then use reduced time span + terminated = false + if hasfield(typeof(sol), :retcode) + if sol.retcode == :Terminated + tspan = (tspan[1], sol.t[end]) + terminated = true + end end - end - tspan = reverse(tspan) + tspan = reverse(tspan) - discrete = (t !== nothing && dg_continuous === nothing) + discrete = (t !== nothing && dg_continuous === nothing) - numstates = length(u0) - numparams = p === nothing || p === DiffEqBase.NullParameters() ? 0 : length(p) + numstates = length(u0) + numparams = p === nothing || p === DiffEqBase.NullParameters() ? 0 : length(p) - len = length(u0)+numparams + len = length(u0) + numparams - if z0===nothing - λ = p === nothing || p === DiffEqBase.NullParameters() ? similar(u0) : one(eltype(u0)) .* similar(p, len) - λ .= false - else - λ = nothing - end + if z0 === nothing + λ = p === nothing || p === DiffEqBase.NullParameters() ? similar(u0) : + one(eltype(u0)) .* similar(p, len) + λ .= false + else + λ = nothing + end - sense = ODEBacksolveSensitivityFunction(g,sensealg,discrete,sol,dg_continuous,f) + sense = ODEBacksolveSensitivityFunction(g, sensealg, discrete, sol, dg_continuous, f) - if z0!==nothing - sense = NILSASSensitivityFunction{isinplace(f),typeof(nilss),typeof(sense),typeof(M)}(nilss,sense,M,discrete) - end + if z0 !== nothing + sense = NILSASSensitivityFunction{isinplace(f), typeof(nilss), typeof(sense), + typeof(M)}(nilss, sense, M, discrete) + end - init_cb = (discrete || dg_discrete!==nothing) # && tspan[1] == t[end] - cb, duplicate_iterator_times = generate_callbacks(sense, dg_discrete, λ, t, tspan[2], callback, init_cb, terminated) - checkpoints = ischeckpointing(sensealg, sol) ? checkpoints : nothing - if checkpoints !== nothing - cb = backsolve_checkpoint_callbacks(sense, sol, checkpoints, cb, duplicate_iterator_times) - end + init_cb = (discrete || dg_discrete !== nothing) # && tspan[1] == t[end] + cb, duplicate_iterator_times = generate_callbacks(sense, dg_discrete, λ, t, tspan[2], + callback, init_cb, terminated) + checkpoints = ischeckpointing(sensealg, sol) ? checkpoints : nothing + if checkpoints !== nothing + cb = backsolve_checkpoint_callbacks(sense, sol, checkpoints, cb, + duplicate_iterator_times) + end - if z0===nothing - z0 = [vec(zero(λ)); vec(sense.y)] - end - original_mm = sol.prob.f.mass_matrix - zzz(A, m, n) = fill!(similar(A, m, n), zero(eltype(original_mm))) - if original_mm === I || original_mm === (I,I) - mm = I - else - sense.diffcache.issemiexplicitdae && @warn "`BacksolveAdjoint` is likely to fail on semi-explicit DAEs, if memory is a concern, please consider using InterpolatingAdjoint(checkpoint=true) instead." - II = Diagonal(I, numparams) - Z1 = zzz(original_mm, numstates, numstates+numparams) - Z2 = zzz(original_mm, numparams, numstates) - mm = [copy(original_mm') Z1 - Z2 II Z2 - Z1 original_mm] - end - jac_prototype = sol.prob.f.jac_prototype - if !sense.discrete || jac_prototype === nothing - adjoint_jac_prototype = nothing - else - J = jac_prototype - Ja = copy(J') - II = Diagonal(I, numparams) - Z1 = zzz(J, numstates, numstates+numparams) - Z2 = zzz(J, numparams, numstates) - adjoint_jac_prototype = [Ja Z1 - Z2 II Z2 - Z1 J] - end - odefun = ODEFunction(sense, mass_matrix=mm, jac_prototype=adjoint_jac_prototype) - return ODEProblem(odefun,z0,tspan,p,callback=cb) + if z0 === nothing + z0 = [vec(zero(λ)); vec(sense.y)] + end + original_mm = sol.prob.f.mass_matrix + zzz(A, m, n) = fill!(similar(A, m, n), zero(eltype(original_mm))) + if original_mm === I || original_mm === (I, I) + mm = I + else + sense.diffcache.issemiexplicitdae && + @warn "`BacksolveAdjoint` is likely to fail on semi-explicit DAEs, if memory is a concern, please consider using InterpolatingAdjoint(checkpoint=true) instead." + II = Diagonal(I, numparams) + Z1 = zzz(original_mm, numstates, numstates + numparams) + Z2 = zzz(original_mm, numparams, numstates) + mm = [copy(original_mm') Z1 + Z2 II Z2 + Z1 original_mm] + end + jac_prototype = sol.prob.f.jac_prototype + if !sense.discrete || jac_prototype === nothing + adjoint_jac_prototype = nothing + else + J = jac_prototype + Ja = copy(J') + II = Diagonal(I, numparams) + Z1 = zzz(J, numstates, numstates + numparams) + Z2 = zzz(J, numparams, numstates) + adjoint_jac_prototype = [Ja Z1 + Z2 II Z2 + Z1 J] + end + odefun = ODEFunction(sense, mass_matrix = mm, jac_prototype = adjoint_jac_prototype) + return ODEProblem(odefun, z0, tspan, p, callback = cb) end -@noinline function SDEAdjointProblem(sol,sensealg::BacksolveAdjoint, - t=nothing, - dg_discrete::DG1=nothing,dg_continuous::DG2=nothing, - g::G=nothing; - checkpoints=sol.t, - callback=CallbackSet(), - corfunc_analytical=nothing,diffusion_jac=nothing,diffusion_paramjac=nothing, - kwargs...) where {DG1,DG2,G} - - dg_discrete===nothing && dg_continuous===nothing && g===nothing && error("Either `dg_discrete`, `dg_continuous`, or `g` must be specified.") - - @unpack f, p, u0, tspan = sol.prob - # check if solution was terminated, then use reduced time span - terminated = false - if hasfield(typeof(sol),:retcode) - if sol.retcode == :Terminated - tspan = (tspan[1], sol.t[end]) - terminated = true +@noinline function SDEAdjointProblem(sol, sensealg::BacksolveAdjoint, + t = nothing, + dg_discrete::DG1 = nothing, + dg_continuous::DG2 = nothing, + g::G = nothing; + checkpoints = sol.t, + callback = CallbackSet(), + corfunc_analytical = nothing, diffusion_jac = nothing, + diffusion_paramjac = nothing, + kwargs...) where {DG1, DG2, G} + dg_discrete === nothing && dg_continuous === nothing && g === nothing && + error("Either `dg_discrete`, `dg_continuous`, or `g` must be specified.") + + @unpack f, p, u0, tspan = sol.prob + # check if solution was terminated, then use reduced time span + terminated = false + if hasfield(typeof(sol), :retcode) + if sol.retcode == :Terminated + tspan = (tspan[1], sol.t[end]) + terminated = true + end end - end - - tspan = reverse(tspan) - discrete = (t !== nothing && dg_continuous === nothing) - - p === DiffEqBase.NullParameters() && error("Your model does not have parameters, and thus it is impossible to calculate the derivative of the solution with respect to the parameters. Your model must have parameters to use parameter sensitivity calculations!") - - numstates = length(u0) - numparams = length(p) - - len = length(u0)+numparams - λ = one(eltype(u0)) .* similar(p, len) - - if StochasticDiffEq.alg_interpretation(sol.alg) == :Stratonovich - sense_drift = ODEBacksolveSensitivityFunction(g,sensealg,discrete,sol,dg_continuous,sol.prob.f) - else - transformed_function = StochasticTransformedFunction(sol,sol.prob.f,sol.prob.g,corfunc_analytical) - drift_function = ODEFunction(transformed_function) - sense_drift = ODEBacksolveSensitivityFunction(g,sensealg,discrete,sol,dg_continuous,drift_function) - end - - diffusion_function = ODEFunction(sol.prob.g, jac=diffusion_jac, paramjac=diffusion_paramjac) - sense_diffusion = ODEBacksolveSensitivityFunction(g,sensealg,discrete,sol,dg_continuous,diffusion_function;noiseterm=true) - - init_cb = (discrete || dg_discrete!==nothing) # && tspan[1] == t[end] - cb, duplicate_iterator_times = generate_callbacks(sense_drift, dg_discrete, λ, t, tspan[2], callback, init_cb, terminated) - checkpoints = ischeckpointing(sensealg, sol) ? checkpoints : nothing - if checkpoints !== nothing - cb = backsolve_checkpoint_callbacks(sense_drift, sol, checkpoints, cb, duplicate_iterator_times) - end - - z0 = [vec(zero(λ)); vec(sense_drift.y)] - - original_mm = sol.prob.f.mass_matrix - if original_mm === I - mm = I - else - sense_drift.diffcache.issemiexplicitdae && @warn "`BacksolveAdjoint` is likely to fail on semi-explicit DAEs, if memory is a concern, please consider using InterpolatingAdjoint(checkpoint=true) instead." - len2 = length(z0) - mm = zeros(len2, len2) - idx = 1:numstates - copyto!(@view(mm[idx, idx]), sol.prob.f.mass_matrix') - idx = numstates+1:numstates+1+numparams - copyto!(@view(mm[idx, idx]), I) - idx = len+1:len2 - copyto!(@view(mm[idx, idx]), sol.prob.f.mass_matrix) - end - - sdefun = SDEFunction(sense_drift,sense_diffusion,mass_matrix=mm) - - # replicated noise - _sol = deepcopy(sol) - backwardnoise = reverse(_sol.W) - - if StochasticDiffEq.is_diagonal_noise(sol.prob) && typeof(sol.W[end])<:Number - # scalar noise case - noise_matrix = nothing - else - noise_matrix = similar(z0,length(z0),numstates) - noise_matrix .= false - end - - return SDEProblem(sdefun,sense_diffusion,z0,tspan,p, - callback=cb, - noise=backwardnoise, - noise_rate_prototype = noise_matrix - ) -end + tspan = reverse(tspan) + discrete = (t !== nothing && dg_continuous === nothing) + + p === DiffEqBase.NullParameters() && + error("Your model does not have parameters, and thus it is impossible to calculate the derivative of the solution with respect to the parameters. Your model must have parameters to use parameter sensitivity calculations!") -@noinline function RODEAdjointProblem(sol,sensealg::BacksolveAdjoint, - t=nothing, - dg_discrete::DG1=nothing,dg_continuous::DG2=nothing, - g::G=nothing; - checkpoints=sol.t, - callback=CallbackSet(), - kwargs...) where {DG1,DG2,G} + numstates = length(u0) + numparams = length(p) - dg_discrete===nothing && dg_continuous===nothing && g===nothing && error("Either `dg_discrete`, `dg_continuous`, or `g` must be specified.") + len = length(u0) + numparams + λ = one(eltype(u0)) .* similar(p, len) - @unpack f, p, u0, tspan = sol.prob - # check if solution was terminated, then use reduced time span - terminated = false - if hasfield(typeof(sol),:retcode) - if sol.retcode == :Terminated - tspan = (tspan[1], sol.t[end]) - terminated = true + if StochasticDiffEq.alg_interpretation(sol.alg) == :Stratonovich + sense_drift = ODEBacksolveSensitivityFunction(g, sensealg, discrete, sol, + dg_continuous, sol.prob.f) + else + transformed_function = StochasticTransformedFunction(sol, sol.prob.f, sol.prob.g, + corfunc_analytical) + drift_function = ODEFunction(transformed_function) + sense_drift = ODEBacksolveSensitivityFunction(g, sensealg, discrete, sol, + dg_continuous, drift_function) end - end - tspan = reverse(tspan) - discrete = (t !== nothing && dg_continuous === nothing) - - p === DiffEqBase.NullParameters() && error("Your model does not have parameters, and thus it is impossible to calculate the derivative of the solution with respect to the parameters. Your model must have parameters to use parameter sensitivity calculations!") - - numstates = length(u0) - numparams = length(p) - - len = length(u0)+numparams - λ = one(eltype(u0)) .* similar(p, len) - - sense = ODEBacksolveSensitivityFunction(g,sensealg,discrete,sol,dg_continuous,f;noiseterm=false) - - init_cb = (discrete || dg_discrete!==nothing) # && tspan[1] == t[end] - cb, duplicate_iterator_times = generate_callbacks(sense, dg_discrete, λ, t, tspan[2], callback, init_cb, terminated) - checkpoints = ischeckpointing(sensealg, sol) ? checkpoints : nothing - if checkpoints !== nothing - cb = backsolve_checkpoint_callbacks(sense, sol, checkpoints, cb, duplicate_iterator_times) - end - - z0 = [vec(zero(λ)); vec(sense.y)] - - original_mm = sol.prob.f.mass_matrix - if original_mm === I - mm = I - else - sense.diffcache.issemiexplicitdae && @warn "`BacksolveAdjoint` is likely to fail on semi-explicit DAEs, if memory is a concern, please consider using InterpolatingAdjoint(checkpoint=true) instead." - len2 = length(z0) - mm = zeros(len2, len2) - idx = 1:numstates - copyto!(@view(mm[idx, idx]), sol.prob.f.mass_matrix') - idx = numstates+1:numstates+1+numparams - copyto!(@view(mm[idx, idx]), I) - idx = len+1:len2 - copyto!(@view(mm[idx, idx]), sol.prob.f.mass_matrix) - end - - rodefun = RODEFunction(sense,mass_matrix=mm) - - # replicated noise - _sol = deepcopy(sol) - backwardnoise = reverse(_sol.W) - - return RODEProblem(rodefun,z0,tspan,p, - callback=cb, - noise=backwardnoise - ) -end + diffusion_function = ODEFunction(sol.prob.g, jac = diffusion_jac, + paramjac = diffusion_paramjac) + sense_diffusion = ODEBacksolveSensitivityFunction(g, sensealg, discrete, sol, + dg_continuous, diffusion_function; + noiseterm = true) + + init_cb = (discrete || dg_discrete !== nothing) # && tspan[1] == t[end] + cb, duplicate_iterator_times = generate_callbacks(sense_drift, dg_discrete, λ, t, + tspan[2], callback, init_cb, + terminated) + checkpoints = ischeckpointing(sensealg, sol) ? checkpoints : nothing + if checkpoints !== nothing + cb = backsolve_checkpoint_callbacks(sense_drift, sol, checkpoints, cb, + duplicate_iterator_times) + end + + z0 = [vec(zero(λ)); vec(sense_drift.y)] -function backsolve_checkpoint_callbacks(sensefun, sol, checkpoints, callback, duplicate_iterator_times=nothing) - prob = sol.prob - if duplicate_iterator_times !== nothing - _checkpoints = filter(x->x ∉ duplicate_iterator_times[1], checkpoints) - else - _checkpoints = checkpoints - end - cur_time = Ref(length(_checkpoints)) - affect! = let sol=sol, cur_time=cur_time, idx=length(prob.u0) - function (integrator) - _y = reshape(@view(integrator.u[end-idx+1:end]), axes(prob.u0)) - sol(_y, integrator.t) - u_modified!(integrator,true) - cur_time[] -= 1 - return nothing + original_mm = sol.prob.f.mass_matrix + if original_mm === I + mm = I + else + sense_drift.diffcache.issemiexplicitdae && + @warn "`BacksolveAdjoint` is likely to fail on semi-explicit DAEs, if memory is a concern, please consider using InterpolatingAdjoint(checkpoint=true) instead." + len2 = length(z0) + mm = zeros(len2, len2) + idx = 1:numstates + copyto!(@view(mm[idx, idx]), sol.prob.f.mass_matrix') + idx = (numstates + 1):(numstates + 1 + numparams) + copyto!(@view(mm[idx, idx]), I) + idx = (len + 1):len2 + copyto!(@view(mm[idx, idx]), sol.prob.f.mass_matrix) end - end + sdefun = SDEFunction(sense_drift, sense_diffusion, mass_matrix = mm) - cb = PresetTimeCallback(_checkpoints,affect!) - return CallbackSet(cb,callback) + # replicated noise + _sol = deepcopy(sol) + backwardnoise = reverse(_sol.W) + + if StochasticDiffEq.is_diagonal_noise(sol.prob) && typeof(sol.W[end]) <: Number + # scalar noise case + noise_matrix = nothing + else + noise_matrix = similar(z0, length(z0), numstates) + noise_matrix .= false + end + + return SDEProblem(sdefun, sense_diffusion, z0, tspan, p, + callback = cb, + noise = backwardnoise, + noise_rate_prototype = noise_matrix) end +@noinline function RODEAdjointProblem(sol, sensealg::BacksolveAdjoint, + t = nothing, + dg_discrete::DG1 = nothing, + dg_continuous::DG2 = nothing, + g::G = nothing; + checkpoints = sol.t, + callback = CallbackSet(), + kwargs...) where {DG1, DG2, G} + dg_discrete === nothing && dg_continuous === nothing && g === nothing && + error("Either `dg_discrete`, `dg_continuous`, or `g` must be specified.") + + @unpack f, p, u0, tspan = sol.prob + # check if solution was terminated, then use reduced time span + terminated = false + if hasfield(typeof(sol), :retcode) + if sol.retcode == :Terminated + tspan = (tspan[1], sol.t[end]) + terminated = true + end + end + tspan = reverse(tspan) + discrete = (t !== nothing && dg_continuous === nothing) + + p === DiffEqBase.NullParameters() && + error("Your model does not have parameters, and thus it is impossible to calculate the derivative of the solution with respect to the parameters. Your model must have parameters to use parameter sensitivity calculations!") + + numstates = length(u0) + numparams = length(p) -function backsolve_checkpoint_callbacks(sensefun::NILSASSensitivityFunction, sol, checkpoints, callback, duplicate_iterator_times=nothing) - prob = sol.prob - if duplicate_iterator_times !== nothing - _checkpoints = filter(x->x ∉ duplicate_iterator_times[1], checkpoints) - else - _checkpoints = checkpoints - end - cur_time = Ref(length(_checkpoints)) - affect! = let sol=sol, cur_time=cur_time - function (integrator) - _y = integrator.u.x[3] - sol(_y, integrator.t) - u_modified!(integrator,true) - cur_time[] -= 1 - return nothing + len = length(u0) + numparams + λ = one(eltype(u0)) .* similar(p, len) + + sense = ODEBacksolveSensitivityFunction(g, sensealg, discrete, sol, dg_continuous, f; + noiseterm = false) + + init_cb = (discrete || dg_discrete !== nothing) # && tspan[1] == t[end] + cb, duplicate_iterator_times = generate_callbacks(sense, dg_discrete, λ, t, tspan[2], + callback, init_cb, terminated) + checkpoints = ischeckpointing(sensealg, sol) ? checkpoints : nothing + if checkpoints !== nothing + cb = backsolve_checkpoint_callbacks(sense, sol, checkpoints, cb, + duplicate_iterator_times) + end + + z0 = [vec(zero(λ)); vec(sense.y)] + + original_mm = sol.prob.f.mass_matrix + if original_mm === I + mm = I + else + sense.diffcache.issemiexplicitdae && + @warn "`BacksolveAdjoint` is likely to fail on semi-explicit DAEs, if memory is a concern, please consider using InterpolatingAdjoint(checkpoint=true) instead." + len2 = length(z0) + mm = zeros(len2, len2) + idx = 1:numstates + copyto!(@view(mm[idx, idx]), sol.prob.f.mass_matrix') + idx = (numstates + 1):(numstates + 1 + numparams) + copyto!(@view(mm[idx, idx]), I) + idx = (len + 1):len2 + copyto!(@view(mm[idx, idx]), sol.prob.f.mass_matrix) end - end + rodefun = RODEFunction(sense, mass_matrix = mm) + + # replicated noise + _sol = deepcopy(sol) + backwardnoise = reverse(_sol.W) + + return RODEProblem(rodefun, z0, tspan, p, + callback = cb, + noise = backwardnoise) +end + +function backsolve_checkpoint_callbacks(sensefun, sol, checkpoints, callback, + duplicate_iterator_times = nothing) + prob = sol.prob + if duplicate_iterator_times !== nothing + _checkpoints = filter(x -> x ∉ duplicate_iterator_times[1], checkpoints) + else + _checkpoints = checkpoints + end + cur_time = Ref(length(_checkpoints)) + affect! = let sol = sol, cur_time = cur_time, idx = length(prob.u0) + function (integrator) + _y = reshape(@view(integrator.u[(end - idx + 1):end]), axes(prob.u0)) + sol(_y, integrator.t) + u_modified!(integrator, true) + cur_time[] -= 1 + return nothing + end + end + + cb = PresetTimeCallback(_checkpoints, affect!) + return CallbackSet(cb, callback) +end + +function backsolve_checkpoint_callbacks(sensefun::NILSASSensitivityFunction, sol, + checkpoints, callback, + duplicate_iterator_times = nothing) + prob = sol.prob + if duplicate_iterator_times !== nothing + _checkpoints = filter(x -> x ∉ duplicate_iterator_times[1], checkpoints) + else + _checkpoints = checkpoints + end + cur_time = Ref(length(_checkpoints)) + affect! = let sol = sol, cur_time = cur_time + function (integrator) + _y = integrator.u.x[3] + sol(_y, integrator.t) + u_modified!(integrator, true) + cur_time[] -= 1 + return nothing + end + end - cb = PresetTimeCallback(_checkpoints,affect!) - return CallbackSet(cb,callback) + cb = PresetTimeCallback(_checkpoints, affect!) + return CallbackSet(cb, callback) end diff --git a/src/callback_tracking.jl b/src/callback_tracking.jl index 319d673dd..6107b68b4 100644 --- a/src/callback_tracking.jl +++ b/src/callback_tracking.jl @@ -4,54 +4,55 @@ the reverse pass. The rationale is explain in: https://github.com/SciML/SciMLSensitivity.jl/issues/4 """ -track_callbacks(cb,t,u,p,sensealg) = track_callbacks(CallbackSet(cb),t,u,p,sensealg) -track_callbacks(cb::CallbackSet,t,u,p,sensealg) = CallbackSet( - map(cb->_track_callback(cb,t,u,p,sensealg), cb.continuous_callbacks), - map(cb->_track_callback(cb,t,u,p,sensealg), cb.discrete_callbacks)) - -mutable struct ImplicitCorrection{T1,T2,T3,T4,T5,T6,T7,T8,T9,T10,T11,RefType} - gt_val::T1 - gu_val::T2 - gt::T3 - gu::T4 - gt_conf::T5 - gu_conf::T6 - condition::T7 - Lu_left::T8 - Lu_right::T9 - dy_left::T10 - dy_right::T11 - cur_time::RefType # initialized as "dummy" Ref that gets overwritten by Ref of loss - terminated::Bool +track_callbacks(cb, t, u, p, sensealg) = track_callbacks(CallbackSet(cb), t, u, p, sensealg) +function track_callbacks(cb::CallbackSet, t, u, p, sensealg) + CallbackSet(map(cb -> _track_callback(cb, t, u, p, sensealg), cb.continuous_callbacks), + map(cb -> _track_callback(cb, t, u, p, sensealg), cb.discrete_callbacks)) end -ImplicitCorrection(cb::DiscreteCallback,t,u,p,sensealg) = nothing -function ImplicitCorrection(cb,t,u,p,sensealg) - condition = cb.condition +mutable struct ImplicitCorrection{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, RefType} + gt_val::T1 + gu_val::T2 + gt::T3 + gu::T4 + gt_conf::T5 + gu_conf::T6 + condition::T7 + Lu_left::T8 + Lu_right::T9 + dy_left::T10 + dy_right::T11 + cur_time::RefType # initialized as "dummy" Ref that gets overwritten by Ref of loss + terminated::Bool +end - gt_val = similar(u,1) - gu_val = similar(u) +ImplicitCorrection(cb::DiscreteCallback, t, u, p, sensealg) = nothing +function ImplicitCorrection(cb, t, u, p, sensealg) + condition = cb.condition - fakeinteg = FakeIntegrator(u,p,t,t) - gt, gu = build_condition_wrappers(cb,condition,u,t,fakeinteg) + gt_val = similar(u, 1) + gu_val = similar(u) - gt_conf = build_deriv_config(sensealg,gt,gt_val,t) - gu_conf = build_grad_config(sensealg,gu,u,p) + fakeinteg = FakeIntegrator(u, p, t, t) + gt, gu = build_condition_wrappers(cb, condition, u, t, fakeinteg) + gt_conf = build_deriv_config(sensealg, gt, gt_val, t) + gu_conf = build_grad_config(sensealg, gu, u, p) - dy_left = similar(u) - dy_right = similar(u) + dy_left = similar(u) + dy_right = similar(u) - Lu_left = similar(u) - Lu_right = similar(u) + Lu_left = similar(u) + Lu_right = similar(u) - cur_time = Ref(1) # initialize the Ref, set to Ref of loss below - terminated = false + cur_time = Ref(1) # initialize the Ref, set to Ref of loss below + terminated = false - ImplicitCorrection(gt_val,gu_val,gt,gu,gt_conf,gu_conf,condition,Lu_left,Lu_right,dy_left,dy_right,cur_time,terminated) + ImplicitCorrection(gt_val, gu_val, gt, gu, gt_conf, gu_conf, condition, Lu_left, + Lu_right, dy_left, dy_right, cur_time, terminated) end -struct TrackedAffect{T,T2,T3,T4,T5,T6} +struct TrackedAffect{T, T2, T3, T4, T5, T6} event_times::Vector{T} tprev::Vector{T} uleft::Vector{T2} @@ -61,85 +62,87 @@ struct TrackedAffect{T,T2,T3,T4,T5,T6} event_idx::Vector{T6} end -TrackedAffect(t::Number,u,p,affect!::Nothing,correction) = nothing -TrackedAffect(t::Number,u,p,affect!,correction) = TrackedAffect(Vector{typeof(t)}(undef,0),Vector{typeof(t)}(undef,0), - Vector{typeof(u)}(undef,0),Vector{typeof(p)}(undef,0),affect!,correction, - Vector{Int}(undef,0)) +TrackedAffect(t::Number, u, p, affect!::Nothing, correction) = nothing +function TrackedAffect(t::Number, u, p, affect!, correction) + TrackedAffect(Vector{typeof(t)}(undef, 0), Vector{typeof(t)}(undef, 0), + Vector{typeof(u)}(undef, 0), Vector{typeof(p)}(undef, 0), affect!, + correction, + Vector{Int}(undef, 0)) +end -function (f::TrackedAffect)(integrator,event_idx=nothing) +function (f::TrackedAffect)(integrator, event_idx = nothing) uleft = deepcopy(integrator.u) pleft = deepcopy(integrator.p) - if event_idx===nothing - f.affect!(integrator) + if event_idx === nothing + f.affect!(integrator) else - f.affect!(integrator,event_idx) + f.affect!(integrator, event_idx) end if integrator.u_modified if isempty(f.event_times) - push!(f.event_times,integrator.t) - push!(f.tprev,integrator.tprev) - push!(f.uleft,uleft) - push!(f.pleft,pleft) - if event_idx !== nothing - push!(f.event_idx,event_idx) - end - else - if !maximum(.≈(integrator.t, f.event_times, rtol=0.0, atol=1e-14)) - push!(f.event_times,integrator.t) - push!(f.tprev,integrator.tprev) - push!(f.uleft,uleft) - push!(f.pleft,pleft) + push!(f.event_times, integrator.t) + push!(f.tprev, integrator.tprev) + push!(f.uleft, uleft) + push!(f.pleft, pleft) if event_idx !== nothing - push!(f.event_idx, event_idx) + push!(f.event_idx, event_idx) + end + else + if !maximum(.≈(integrator.t, f.event_times, rtol = 0.0, atol = 1e-14)) + push!(f.event_times, integrator.t) + push!(f.tprev, integrator.tprev) + push!(f.uleft, uleft) + push!(f.pleft, pleft) + if event_idx !== nothing + push!(f.event_idx, event_idx) + end end - end end end end -function _track_callback(cb::DiscreteCallback,t,u,p,sensealg) - correction = ImplicitCorrection(cb,t,u,p,sensealg) +function _track_callback(cb::DiscreteCallback, t, u, p, sensealg) + correction = ImplicitCorrection(cb, t, u, p, sensealg) DiscreteCallback(cb.condition, - TrackedAffect(t,u,p,cb.affect!,correction), + TrackedAffect(t, u, p, cb.affect!, correction), cb.initialize, cb.finalize, cb.save_positions) end -function _track_callback(cb::ContinuousCallback,t,u,p,sensealg) - correction = ImplicitCorrection(cb,t,u,p,sensealg) - ContinuousCallback( - cb.condition, - TrackedAffect(t,u,p,cb.affect!,correction), - TrackedAffect(t,u,p,cb.affect_neg!,correction), - cb.initialize, - cb.finalize, - cb.idxs, - cb.rootfind,cb.interp_points, - cb.save_positions, - cb.dtrelax,cb.abstol,cb.reltol,cb.repeat_nudge) -end - -function _track_callback(cb::VectorContinuousCallback,t,u,p,sensealg) - correction = ImplicitCorrection(cb,t,u,p,sensealg) - VectorContinuousCallback( - cb.condition, - TrackedAffect(t,u,p,cb.affect!,correction), - TrackedAffect(t,u,p,cb.affect_neg!,correction), - cb.len,cb.initialize,cb.finalize,cb.idxs, - cb.rootfind,cb.interp_points, - collect(cb.save_positions), - cb.dtrelax,cb.abstol,cb.reltol,cb.repeat_nudge) -end - -struct FakeIntegrator{uType,P,tType,tprevType} +function _track_callback(cb::ContinuousCallback, t, u, p, sensealg) + correction = ImplicitCorrection(cb, t, u, p, sensealg) + ContinuousCallback(cb.condition, + TrackedAffect(t, u, p, cb.affect!, correction), + TrackedAffect(t, u, p, cb.affect_neg!, correction), + cb.initialize, + cb.finalize, + cb.idxs, + cb.rootfind, cb.interp_points, + cb.save_positions, + cb.dtrelax, cb.abstol, cb.reltol, cb.repeat_nudge) +end + +function _track_callback(cb::VectorContinuousCallback, t, u, p, sensealg) + correction = ImplicitCorrection(cb, t, u, p, sensealg) + VectorContinuousCallback(cb.condition, + TrackedAffect(t, u, p, cb.affect!, correction), + TrackedAffect(t, u, p, cb.affect_neg!, correction), + cb.len, cb.initialize, cb.finalize, cb.idxs, + cb.rootfind, cb.interp_points, + collect(cb.save_positions), + cb.dtrelax, cb.abstol, cb.reltol, cb.repeat_nudge) +end + +struct FakeIntegrator{uType, P, tType, tprevType} u::uType p::P t::tType tprev::tprevType end -struct CallbackSensitivityFunction{fType,Alg<:DiffEqBase.AbstractSensitivityAlgorithm,C<:AdjointDiffCache,pType} <: SensitivityFunction +struct CallbackSensitivityFunction{fType, Alg <: DiffEqBase.AbstractSensitivityAlgorithm, + C <: AdjointDiffCache, pType} <: SensitivityFunction f::fType sensealg::Alg diffcache::C @@ -156,18 +159,23 @@ vjps as described in https://arxiv.org/pdf/1905.10403.pdf Equation 13. For more information, see https://github.com/SciML/SciMLSensitivity.jl/issues/4 """ -setup_reverse_callbacks(cb,sensealg,g,cur_time,terminated) = setup_reverse_callbacks(CallbackSet(cb),sensealg,g,cur_time,terminated) -function setup_reverse_callbacks(cb::CallbackSet,sensealg,g,cur_time,terminated) - cb = CallbackSet(_setup_reverse_callbacks.(cb.continuous_callbacks,(sensealg,),(g,),(cur_time,),(terminated,))..., - reverse(_setup_reverse_callbacks.(cb.discrete_callbacks,(sensealg,),(g,),(cur_time,),(terminated,)))...) +function setup_reverse_callbacks(cb, sensealg, g, cur_time, terminated) + setup_reverse_callbacks(CallbackSet(cb), sensealg, g, cur_time, terminated) +end +function setup_reverse_callbacks(cb::CallbackSet, sensealg, g, cur_time, terminated) + cb = CallbackSet(_setup_reverse_callbacks.(cb.continuous_callbacks, (sensealg,), (g,), + (cur_time,), (terminated,))..., + reverse(_setup_reverse_callbacks.(cb.discrete_callbacks, (sensealg,), + (g,), (cur_time,), (terminated,)))...) return cb end -function _setup_reverse_callbacks(cb::Union{ContinuousCallback,DiscreteCallback,VectorContinuousCallback},sensealg,g,loss_ref,terminated) - - if cb isa Union{ContinuousCallback,VectorContinuousCallback} && cb.affect! !== nothing - cb.affect!.correction.cur_time = loss_ref # set cur_time - cb.affect!.correction.terminated = terminated # flag if time evolution was terminated by callback +function _setup_reverse_callbacks(cb::Union{ContinuousCallback, DiscreteCallback, + VectorContinuousCallback}, sensealg, g, + loss_ref, terminated) + if cb isa Union{ContinuousCallback, VectorContinuousCallback} && cb.affect! !== nothing + cb.affect!.correction.cur_time = loss_ref # set cur_time + cb.affect!.correction.terminated = terminated # flag if time evolution was terminated by callback end # ReverseLossCallback adds gradients before and after the callback if save_positions is (true, true). @@ -177,332 +185,343 @@ function _setup_reverse_callbacks(cb::Union{ContinuousCallback,DiscreteCallback, # if save_positions = [0,1] the gradient contribution is added before the callback but no additional gradient is added afterwards. # if save_positions = [1,0] the gradient contribution is added before, and in principle we would need to correct the adjoint state again. Thefore, - cb.save_positions == [1,0] && error("save_positions=[1,0] is currently not supported.") + cb.save_positions == [1, 0] && error("save_positions=[1,0] is currently not supported.") function affect!(integrator) - - indx, pos_neg = get_indx(cb,integrator.t) - tprev = get_tprev(cb,indx,pos_neg) - event_idx = cb isa VectorContinuousCallback ? get_event_idx(cb,indx,pos_neg) : nothing - - w = let tprev=tprev, pos_neg=pos_neg, event_idx=event_idx - function (du,u,p,t) - _affect! = get_affect!(cb,pos_neg) - fakeinteg = FakeIntegrator([x for x in u],[x for x in p],t,tprev) + indx, pos_neg = get_indx(cb, integrator.t) + tprev = get_tprev(cb, indx, pos_neg) + event_idx = cb isa VectorContinuousCallback ? get_event_idx(cb, indx, pos_neg) : + nothing + + w = let tprev = tprev, pos_neg = pos_neg, event_idx = event_idx + function (du, u, p, t) + _affect! = get_affect!(cb, pos_neg) + fakeinteg = FakeIntegrator([x for x in u], [x for x in p], t, tprev) if cb isa VectorContinuousCallback - _affect!(fakeinteg,event_idx) + _affect!(fakeinteg, event_idx) else - _affect!(fakeinteg) + _affect!(fakeinteg) end du .= fakeinteg.u - end + end end S = integrator.f.f # get the sensitivity function # Create a fake sensitivity function to do the vjps - fakeS = CallbackSensitivityFunction(w,sensealg,S.diffcache,integrator.sol.prob) + fakeS = CallbackSensitivityFunction(w, sensealg, S.diffcache, integrator.sol.prob) du = first(get_tmp_cache(integrator)) - λ,grad,y,dλ,dgrad,dy = split_states(du,integrator.u,integrator.t,S) + λ, grad, y, dλ, dgrad, dy = split_states(du, integrator.u, integrator.t, S) # if save_positions[2] = false, then the right limit is not saved. Thus, for # the QuadratureAdjoint we would need to lift y from the left to the right limit. # However, one also needs to update dgrad later on. if (sensealg isa QuadratureAdjoint && !cb.save_positions[2]) # || (sensealg isa InterpolatingAdjoint && ischeckpointing(sensealg)) - # lifting for InterpolatingAdjoint is not needed anymore. Callback is already applied. - w(y,y,integrator.p,integrator.t) + # lifting for InterpolatingAdjoint is not needed anymore. Callback is already applied. + w(y, y, integrator.p, integrator.t) end - if cb isa Union{ContinuousCallback,VectorContinuousCallback} - # correction of the loss function sensitivity for continuous callbacks - # wrt dependence of event time t on parameters and initial state. - # Must be handled here because otherwise it is unclear if continuous or - # discrete callback was triggered. - @unpack correction = cb.affect! - @unpack dy_right, Lu_right = correction - # compute #f(xτ_right,p_right,τ(x₀,p)) - compute_f!(dy_right,S,y,integrator) - # if callback did not terminate the time evolution, we have to compute one more correction term. - if cb.save_positions[2] && !correction.terminated - loss_indx = correction.cur_time[] + 1 - loss_correction!(Lu_right,y,integrator,g,loss_indx) - else - Lu_right .*= false - end + if cb isa Union{ContinuousCallback, VectorContinuousCallback} + # correction of the loss function sensitivity for continuous callbacks + # wrt dependence of event time t on parameters and initial state. + # Must be handled here because otherwise it is unclear if continuous or + # discrete callback was triggered. + @unpack correction = cb.affect! + @unpack dy_right, Lu_right = correction + # compute #f(xτ_right,p_right,τ(x₀,p)) + compute_f!(dy_right, S, y, integrator) + # if callback did not terminate the time evolution, we have to compute one more correction term. + if cb.save_positions[2] && !correction.terminated + loss_indx = correction.cur_time[] + 1 + loss_correction!(Lu_right, y, integrator, g, loss_indx) + else + Lu_right .*= false + end end - update_p = copy_to_integrator!(cb,y,integrator.p,integrator.t,indx,pos_neg) + update_p = copy_to_integrator!(cb, y, integrator.p, integrator.t, indx, pos_neg) # reshape u and du (y and dy) to match forward pass (e.g., for matrices as initial conditions). Only needed for BacksolveAdjoint if sensealg isa BacksolveAdjoint - _size = pos_neg ? size(cb.affect!.uleft[indx]) : size(cb.affect_neg!.uleft[indx]) - y = reshape(y, _size) - dy = reshape(dy, _size) + _size = pos_neg ? size(cb.affect!.uleft[indx]) : + size(cb.affect_neg!.uleft[indx]) + y = reshape(y, _size) + dy = reshape(dy, _size) end - if cb isa Union{ContinuousCallback,VectorContinuousCallback} - # compute the correction of the right limit (with left state limit inserted into dgdt) - @unpack dy_left, cur_time = correction - compute_f!(dy_left,S,y,integrator) - dgdt(dy_left,correction,sensealg,y,integrator,tprev,event_idx) - if !correction.terminated - implicit_correction!(Lu_right,dλ,λ,dy_right,correction) - correction.terminated = false # additional callbacks might have happened which didn't terminate the time evolution - end + if cb isa Union{ContinuousCallback, VectorContinuousCallback} + # compute the correction of the right limit (with left state limit inserted into dgdt) + @unpack dy_left, cur_time = correction + compute_f!(dy_left, S, y, integrator) + dgdt(dy_left, correction, sensealg, y, integrator, tprev, event_idx) + if !correction.terminated + implicit_correction!(Lu_right, dλ, λ, dy_right, correction) + correction.terminated = false # additional callbacks might have happened which didn't terminate the time evolution + end end if update_p - # changes in parameters - if !(sensealg isa QuadratureAdjoint) - - wp = let tprev=tprev, pos_neg=pos_neg, event_idx=event_idx - function (dp,p,u,t) - _affect! = get_affect!(cb,pos_neg) - fakeinteg = FakeIntegrator([x for x in u],[x for x in p],t,tprev) - if cb isa VectorContinuousCallback - _affect!(fakeinteg, event_idx) - else - _affect!(fakeinteg) + # changes in parameters + if !(sensealg isa QuadratureAdjoint) + wp = let tprev = tprev, pos_neg = pos_neg, event_idx = event_idx + function (dp, p, u, t) + _affect! = get_affect!(cb, pos_neg) + fakeinteg = FakeIntegrator([x for x in u], [x for x in p], t, tprev) + if cb isa VectorContinuousCallback + _affect!(fakeinteg, event_idx) + else + _affect!(fakeinteg) + end + dp .= fakeinteg.p + end end - dp .= fakeinteg.p - end + fakeSp = CallbackSensitivityFunction(wp, sensealg, S.diffcache, + integrator.sol.prob) + #vjp with Jacobin given by dw/dp before event and vector given by grad + vecjacobian!(dgrad, integrator.p, grad, y, integrator.t, fakeSp; + dgrad = nothing, dy = nothing) + grad .= dgrad end - fakeSp = CallbackSensitivityFunction(wp,sensealg,S.diffcache,integrator.sol.prob) - #vjp with Jacobin given by dw/dp before event and vector given by grad - vecjacobian!(dgrad, integrator.p, grad, y, integrator.t, fakeSp; - dgrad=nothing, dy=nothing) - grad .= dgrad - end end - vecjacobian!(dλ, y, λ, integrator.p, integrator.t, fakeS; - dgrad=dgrad, dy=dy) - - dgrad!==nothing && (dgrad .*= -1) - if cb isa Union{ContinuousCallback,VectorContinuousCallback} - # second correction to correct for left limit - @unpack Lu_left = correction - implicit_correction!(Lu_left,dλ,dy_left,correction) - dλ .+= Lu_left - Lu_right - - if cb.save_positions[1] == true - # if the callback saved the first position, we need to implicitly correct this value as well - loss_indx = correction.cur_time[] - implicit_correction!(Lu_left,dy_left,correction,y,integrator,g,loss_indx) - dλ .+= Lu_left - end + dgrad = dgrad, dy = dy) + + dgrad !== nothing && (dgrad .*= -1) + if cb isa Union{ContinuousCallback, VectorContinuousCallback} + # second correction to correct for left limit + @unpack Lu_left = correction + implicit_correction!(Lu_left, dλ, dy_left, correction) + dλ .+= Lu_left - Lu_right + + if cb.save_positions[1] == true + # if the callback saved the first position, we need to implicitly correct this value as well + loss_indx = correction.cur_time[] + implicit_correction!(Lu_left, dy_left, correction, y, integrator, g, + loss_indx) + dλ .+= Lu_left + end end λ .= dλ if !(sensealg isa QuadratureAdjoint) - grad .-= dgrad + grad .-= dgrad end - end times = if typeof(cb) <: DiscreteCallback cb.affect!.event_times else - [cb.affect!.event_times;cb.affect_neg!.event_times] + [cb.affect!.event_times; cb.affect_neg!.event_times] end PresetTimeCallback(times, affect!, - save_positions = (false,false)) -end - -get_indx(cb::DiscreteCallback,t) = (searchsortedfirst(cb.affect!.event_times,t), true) -function get_indx(cb::Union{ContinuousCallback,VectorContinuousCallback}, t) - if !isempty(cb.affect!.event_times) || !isempty(cb.affect_neg!.event_times) - indx = searchsortedfirst(cb.affect!.event_times,t) - indx_neg = searchsortedfirst(cb.affect_neg!.event_times,t) - if !isempty(cb.affect!.event_times) && cb.affect!.event_times[min(indx,length(cb.affect!.event_times))]==t - return indx, true - elseif !isempty(cb.affect_neg!.event_times) && cb.affect_neg!.event_times[min(indx_neg,length(cb.affect_neg!.event_times))]==t - return indx_neg, false + save_positions = (false, false)) +end + +get_indx(cb::DiscreteCallback, t) = (searchsortedfirst(cb.affect!.event_times, t), true) +function get_indx(cb::Union{ContinuousCallback, VectorContinuousCallback}, t) + if !isempty(cb.affect!.event_times) || !isempty(cb.affect_neg!.event_times) + indx = searchsortedfirst(cb.affect!.event_times, t) + indx_neg = searchsortedfirst(cb.affect_neg!.event_times, t) + if !isempty(cb.affect!.event_times) && + cb.affect!.event_times[min(indx, length(cb.affect!.event_times))] == t + return indx, true + elseif !isempty(cb.affect_neg!.event_times) && + cb.affect_neg!.event_times[min(indx_neg, length(cb.affect_neg!.event_times))] == + t + return indx_neg, false + else + error("Event was triggered but no corresponding event in ContinuousCallback was found. Please report this error.") + end else - error("Event was triggered but no corresponding event in ContinuousCallback was found. Please report this error.") + error("No event was recorded. Please report this error.") end - else - error("No event was recorded. Please report this error.") - end end -get_tprev(cb::DiscreteCallback,indx,bool) = cb.affect!.tprev[indx] -function get_tprev(cb::Union{ContinuousCallback,VectorContinuousCallback}, indx, bool) - if bool - return cb.affect!.tprev[indx] - else - return cb.affect_neg!.tprev[indx] - end +get_tprev(cb::DiscreteCallback, indx, bool) = cb.affect!.tprev[indx] +function get_tprev(cb::Union{ContinuousCallback, VectorContinuousCallback}, indx, bool) + if bool + return cb.affect!.tprev[indx] + else + return cb.affect_neg!.tprev[indx] + end end function get_event_idx(cb::VectorContinuousCallback, indx, bool) - if bool - return cb.affect!.event_idx[indx] - else - return cb.affect_neg!.event_idx[indx] - end + if bool + return cb.affect!.event_idx[indx] + else + return cb.affect_neg!.event_idx[indx] + end end function copy_to_integrator!(cb::DiscreteCallback, y, p, t, indx, bool) - copyto!(y, cb.affect!.uleft[indx]) - update_p = (p != cb.affect!.pleft[indx]) - update_p && copyto!(p, cb.affect!.pleft[indx]) - update_p -end - -function copy_to_integrator!(cb::Union{ContinuousCallback,VectorContinuousCallback}, y, p, t, indx, bool) - if bool copyto!(y, cb.affect!.uleft[indx]) update_p = (p != cb.affect!.pleft[indx]) update_p && copyto!(p, cb.affect!.pleft[indx]) - else - copyto!(y, cb.affect_neg!.uleft[indx]) - update_p = (p != cb.affect_neg!.pleft[indx]) - update_p && copyto!(p, cb.affect_neg!.pleft[indx]) - end - update_p + update_p +end + +function copy_to_integrator!(cb::Union{ContinuousCallback, VectorContinuousCallback}, y, p, + t, indx, bool) + if bool + copyto!(y, cb.affect!.uleft[indx]) + update_p = (p != cb.affect!.pleft[indx]) + update_p && copyto!(p, cb.affect!.pleft[indx]) + else + copyto!(y, cb.affect_neg!.uleft[indx]) + update_p = (p != cb.affect_neg!.pleft[indx]) + update_p && copyto!(p, cb.affect_neg!.pleft[indx]) + end + update_p end -function compute_f!(dy,S,y,integrator) - p, t = integrator.p, integrator.t +function compute_f!(dy, S, y, integrator) + p, t = integrator.p, integrator.t - if inplace_sensitivity(S) - S.f(dy,y,p,t) - else - dy[:] .= S.f(y,p,t) - end - return nothing + if inplace_sensitivity(S) + S.f(dy, y, p, t) + else + dy[:] .= S.f(y, p, t) + end + return nothing end -function dgdt(dy,correction,sensealg,y,integrator,tprev,event_idx) - # dy refers to f evaluated on left limit - @unpack gt_val, gu_val, gt, gu, gt_conf, gu_conf, condition = correction +function dgdt(dy, correction, sensealg, y, integrator, tprev, event_idx) + # dy refers to f evaluated on left limit + @unpack gt_val, gu_val, gt, gu, gt_conf, gu_conf, condition = correction - p, t = integrator.p, integrator.t + p, t = integrator.p, integrator.t - fakeinteg = FakeIntegrator([x for x in y],p,t,tprev) + fakeinteg = FakeIntegrator([x for x in y], p, t, tprev) - # derivative and gradient of condition with respect to time and state, respectively - gt.u = y - gt.integrator = fakeinteg + # derivative and gradient of condition with respect to time and state, respectively + gt.u = y + gt.integrator = fakeinteg - gu.t = t - gu.integrator = fakeinteg + gu.t = t + gu.integrator = fakeinteg - # for VectorContinuousCallback we also need to set the event_idx. - if gt isa VectorConditionTimeWrapper - gt.event_idx = event_idx - gu.event_idx = event_idx + # for VectorContinuousCallback we also need to set the event_idx. + if gt isa VectorConditionTimeWrapper + gt.event_idx = event_idx + gu.event_idx = event_idx - # safety check: evaluate condition to check if several conditions were true. - # This is currently not supported - condition(gt.out_cache,y,t,integrator) - gt.out_cache .= abs.(gt.out_cache) .< 1000*eps(eltype(gt.out_cache)) - (sum(gt.out_cache)!=1 || gt.out_cache[event_idx]!=1) && error("Either several events were triggered or `event_idx` was falsely identified. Output of conditions $(gt.out_cache)") - end + # safety check: evaluate condition to check if several conditions were true. + # This is currently not supported + condition(gt.out_cache, y, t, integrator) + gt.out_cache .= abs.(gt.out_cache) .< 1000 * eps(eltype(gt.out_cache)) + (sum(gt.out_cache) != 1 || gt.out_cache[event_idx] != 1) && + error("Either several events were triggered or `event_idx` was falsely identified. Output of conditions $(gt.out_cache)") + end - derivative!(gt_val, gt, t, sensealg, gt_conf) - gradient!(gu_val, gu, y, sensealg, gu_conf) + derivative!(gt_val, gt, t, sensealg, gt_conf) + gradient!(gu_val, gu, y, sensealg, gu_conf) - gt_val .+= dot(gu_val,dy) - @. gt_val = inv(gt_val) # allocates? + gt_val .+= dot(gu_val, dy) + @. gt_val = inv(gt_val) # allocates? - @. gu_val *= -gt_val - return nothing + @. gu_val *= -gt_val + return nothing end -function loss_correction!(Lu,y,integrator,g,indx) - # ∂L∂t correction should be added if L depends explicitly on time. - p, t = integrator.p, integrator.t - g(Lu,y,p,t,indx) - return nothing +function loss_correction!(Lu, y, integrator, g, indx) + # ∂L∂t correction should be added if L depends explicitly on time. + p, t = integrator.p, integrator.t + g(Lu, y, p, t, indx) + return nothing end -function implicit_correction!(Lu,dλ,λ,dy,correction) - @unpack gu_val = correction +function implicit_correction!(Lu, dλ, λ, dy, correction) + @unpack gu_val = correction - # remove gradients from adjoint state to compute correction factor - @. dλ = λ - Lu - Lu .= dot(dλ,dy)*gu_val + # remove gradients from adjoint state to compute correction factor + @. dλ = λ - Lu + Lu .= dot(dλ, dy) * gu_val - return nothing + return nothing end -function implicit_correction!(Lu,λ,dy,correction) - @unpack gu_val = correction +function implicit_correction!(Lu, λ, dy, correction) + @unpack gu_val = correction - Lu .= dot(λ,dy)*gu_val + Lu .= dot(λ, dy) * gu_val - return nothing + return nothing end -function implicit_correction!(Lu,dy,correction,y,integrator,g,indx) - @unpack gu_val = correction +function implicit_correction!(Lu, dy, correction, y, integrator, g, indx) + @unpack gu_val = correction - p, t = integrator.p, integrator.t + p, t = integrator.p, integrator.t - # loss function gradient (not condition!) - # ∂L∂t correction should be added, also ∂L∂p is missing. - # correct adjoint - g(Lu,y,p,t,indx) + # loss function gradient (not condition!) + # ∂L∂t correction should be added, also ∂L∂p is missing. + # correct adjoint + g(Lu, y, p, t, indx) - Lu .= dot(Lu,dy)*gu_val + Lu .= dot(Lu, dy) * gu_val - # note that we don't add the gradient Lu here again to the correction because it will be added by the ReverseLossCallback. - return nothing + # note that we don't add the gradient Lu here again to the correction because it will be added by the ReverseLossCallback. + return nothing end # ConditionTimeWrapper: Wrapper for implicit correction for ContinuousCallback # VectorConditionTimeWrapper: Wrapper for implicit correction for VectorContinuousCallback -function build_condition_wrappers(cb::ContinuousCallback,condition,u,t,fakeinteg) - gt = ConditionTimeWrapper(condition,u,fakeinteg) - gu = ConditionUWrapper(condition,t,fakeinteg) - return gt, gu -end -function build_condition_wrappers(cb::VectorContinuousCallback,condition,u,t,fakeinteg) - out = similar(u, cb.len) # create a cache for condition function (out,u,t,integrator) - gt = VectorConditionTimeWrapper(condition,u,fakeinteg,1,out) - gu = VectorConditionUWrapper(condition,t,fakeinteg,1,out) - return gt, gu -end -mutable struct ConditionTimeWrapper{F,uType,Integrator} <: Function - f::F - u::uType - integrator::Integrator -end -(ff::ConditionTimeWrapper)(t) = [ff.f(ff.u,t,ff.integrator)] -mutable struct ConditionUWrapper{F,tType,Integrator} <: Function - f::F - t::tType - integrator::Integrator -end -(ff::ConditionUWrapper)(u) = ff.f(u,ff.t,ff.integrator) -mutable struct VectorConditionTimeWrapper{F,uType,Integrator,outType} <: Function - f::F - u::uType - integrator::Integrator - event_idx::Int - out_cache::outType -end -(ff::VectorConditionTimeWrapper)(t) = (ff.f(ff.out_cache,ff.u,t,ff.integrator); [ff.out_cache[ff.event_idx]]) - -mutable struct VectorConditionUWrapper{F,tType,Integrator,outType} <: Function - f::F - t::tType - integrator::Integrator - event_idx::Int - out_cache::outType -end -(ff::VectorConditionUWrapper)(u) = (out = similar(u,length(ff.out_cache)); ff.f(out,u,ff.t,ff.integrator); out[ff.event_idx]) +function build_condition_wrappers(cb::ContinuousCallback, condition, u, t, fakeinteg) + gt = ConditionTimeWrapper(condition, u, fakeinteg) + gu = ConditionUWrapper(condition, t, fakeinteg) + return gt, gu +end +function build_condition_wrappers(cb::VectorContinuousCallback, condition, u, t, fakeinteg) + out = similar(u, cb.len) # create a cache for condition function (out,u,t,integrator) + gt = VectorConditionTimeWrapper(condition, u, fakeinteg, 1, out) + gu = VectorConditionUWrapper(condition, t, fakeinteg, 1, out) + return gt, gu +end +mutable struct ConditionTimeWrapper{F, uType, Integrator} <: Function + f::F + u::uType + integrator::Integrator +end +(ff::ConditionTimeWrapper)(t) = [ff.f(ff.u, t, ff.integrator)] +mutable struct ConditionUWrapper{F, tType, Integrator} <: Function + f::F + t::tType + integrator::Integrator +end +(ff::ConditionUWrapper)(u) = ff.f(u, ff.t, ff.integrator) +mutable struct VectorConditionTimeWrapper{F, uType, Integrator, outType} <: Function + f::F + u::uType + integrator::Integrator + event_idx::Int + out_cache::outType +end +function (ff::VectorConditionTimeWrapper)(t) + (ff.f(ff.out_cache, ff.u, t, ff.integrator); [ff.out_cache[ff.event_idx]]) +end + +mutable struct VectorConditionUWrapper{F, tType, Integrator, outType} <: Function + f::F + t::tType + integrator::Integrator + event_idx::Int + out_cache::outType +end +function (ff::VectorConditionUWrapper)(u) + (out = similar(u, length(ff.out_cache)); ff.f(out, u, ff.t, ff.integrator); out[ff.event_idx]) +end DiffEqBase.terminate!(i::FakeIntegrator) = nothing # get the affect function of the callback. For example, allows us to get the `f` in PeriodicCallback without the integrator.tstops handling. -get_affect!(cb::DiscreteCallback,bool) = get_affect!(cb.affect!) -get_affect!(cb::Union{ContinuousCallback,VectorContinuousCallback},bool) = bool ? get_affect!(cb.affect!) : get_affect!(cb.affect_neg!) +get_affect!(cb::DiscreteCallback, bool) = get_affect!(cb.affect!) +function get_affect!(cb::Union{ContinuousCallback, VectorContinuousCallback}, bool) + bool ? get_affect!(cb.affect!) : get_affect!(cb.affect_neg!) +end get_affect!(affect!::TrackedAffect) = get_affect!(affect!.affect!) get_affect!(affect!) = affect! get_affect!(affect!::DiffEqCallbacks.PeriodicCallbackAffect) = affect!.affect! diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index 0f730bf64..5574d807f 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -1,955 +1,1064 @@ -## High level - -# Here is where we can add a default algorithm for computing sensitivities -# Based on problem information! - -function inplace_vjp(prob,u0,p,verbose) - du = copy(u0) - ez = try - Enzyme.autodiff(Enzyme.Duplicated(du, du), - copy(u0),copy(p),prob.tspan[1]) do out,u,_p,t - prob.f(out, u, _p, t) - nothing - end - true - catch - false - end - if ez - return EnzymeVJP() - end - - # Determine if we can compile ReverseDiff - compile = try - if DiffEqBase.isinplace(prob) - !hasbranching(prob.f,copy(u0),u0,p,prob.tspan[1]) - else - !hasbranching(prob.f,u0,p,prob.tspan[1]) - end - catch - false - end - - vjp = try - ReverseDiff.GradientTape((copy(u0), p, [prob.tspan[1]])) do u,p,t - du1 = similar(u, size(u)) - prob.f(du1,u,p,first(t)) - return vec(du1) - end - ReverseDiffVJP(compile) - catch - false - end - return vjp -end - -function automatic_sensealg_choice(prob::Union{ODEProblem,SDEProblem},u0,p,verbose) - - default_sensealg = if p !== DiffEqBase.NullParameters() && - !(eltype(u0) <: ForwardDiff.Dual) && - !(eltype(p) <: ForwardDiff.Dual) && - !(eltype(u0) <: Complex) && - !(eltype(p) <: Complex) && - length(u0) + length(p) <= 100 - ForwardDiffSensitivity() - elseif u0 isa GPUArraysCore.AbstractGPUArray || !DiffEqBase.isinplace(prob) - # only Zygote is GPU compatible and fast - # so if out-of-place, try Zygote - if p === nothing || p === DiffEqBase.NullParameters() - # QuadratureAdjoint skips all p calculations until the end - # So it's the fastest when there are no parameters - QuadratureAdjoint(autojacvec=ZygoteVJP()) - else - InterpolatingAdjoint(autojacvec=ZygoteVJP()) - end - else - vjp = inplace_vjp(prob,u0,p,verbose) - if p === nothing || p === DiffEqBase.NullParameters() - QuadratureAdjoint(autojacvec=vjp) - else - InterpolatingAdjoint(autojacvec=vjp) - end - end - return default_sensealg -end - -function automatic_sensealg_choice(prob::Union{NonlinearProblem,SteadyStateProblem}, u0, p, verbose) - - default_sensealg = if u0 isa GPUArraysCore.AbstractGPUArray || !DiffEqBase.isinplace(prob) - # autodiff = false because forwarddiff fails on many GPU kernels - # this only effects the Jacobian calculation and is same computation order - SteadyStateAdjoint(autodiff=false, autojacvec=ZygoteVJP()) - else - vjp = inplace_vjp(prob,u0,p,verbose) - SteadyStateAdjoint(autojacvec=vjp) - end - return default_sensealg -end - -function DiffEqBase._concrete_solve_adjoint(prob::Union{ODEProblem,SDEProblem}, - alg,sensealg::Nothing,u0,p,originator::SciMLBase.ADOriginator,args...; - verbose=true,kwargs...) - - if haskey(kwargs,:callback) - has_cb = kwargs[:callback]!==nothing - else - has_cb = false - end - default_sensealg = automatic_sensealg_choice(prob,u0,p,verbose) - if has_cb && typeof(default_sensealg) <: AbstractAdjointSensitivityAlgorithm - default_sensealg = setvjp(default_sensealg, ReverseDiffVJP()) - end - DiffEqBase._concrete_solve_adjoint(prob,alg,default_sensealg,u0,p,originator::SciMLBase.ADOriginator,args...;verbose,kwargs...) -end - -function DiffEqBase._concrete_solve_adjoint(prob::Union{NonlinearProblem,SteadyStateProblem},alg, - sensealg::Nothing,u0,p,originator::SciMLBase.ADOriginator,args...; - verbose=true,kwargs...) - - default_sensealg = automatic_sensealg_choice(prob, u0, p, verbose) - DiffEqBase._concrete_solve_adjoint(prob,alg,default_sensealg,u0,p,originator::SciMLBase.ADOriginator,args...;verbose,kwargs...) -end - -function DiffEqBase._concrete_solve_adjoint(prob::Union{DiscreteProblem,DDEProblem, - SDDEProblem,DAEProblem}, - alg,sensealg::Nothing, - u0,p,originator::SciMLBase.ADOriginator,args...;kwargs...) - if length(u0) + length(p) > 100 - default_sensealg = ReverseDiffAdjoint() - else - default_sensealg = ForwardDiffSensitivity() - end - DiffEqBase._concrete_solve_adjoint(prob,alg,default_sensealg,u0,p,originator::SciMLBase.ADOriginator,args...;kwargs...) -end - -function DiffEqBase._concrete_solve_adjoint(prob,alg, - sensealg::AbstractAdjointSensitivityAlgorithm, - u0,p,originator::SciMLBase.ADOriginator,args...;save_start=true,save_end=true, - saveat = eltype(prob.tspan)[], - save_idxs = nothing, - kwargs...) - - if !(typeof(p) <: Union{Nothing,SciMLBase.NullParameters,AbstractArray}) || (p isa AbstractArray && !Base.isconcretetype(eltype(p))) - throw(AdjointSensitivityParameterCompatibilityError()) - end - - # Remove saveat, etc. from kwargs since it's handled separately - # and letting it jump back in there can break the adjoint - kwargs_prob = NamedTuple(filter(x->x[1] != :saveat && x[1] != :save_start && x[1] != :save_end && x[1] != :save_idxs,prob.kwargs)) - - if haskey(kwargs, :callback) - cb = track_callbacks(CallbackSet(kwargs[:callback]),prob.tspan[1],prob.u0,prob.p,sensealg) - _prob = remake(prob;u0=u0,p=p,kwargs = merge(kwargs_prob,(;callback=cb))) - else - cb = nothing - _prob = remake(prob;u0=u0,p=p,kwargs = kwargs_prob) - end - - # Remove callbacks, saveat, etc. from kwargs since it's handled separately - kwargs_fwd = NamedTuple{Base.diff_names(Base._nt_names( - values(kwargs)), (:callback,))}(values(kwargs)) - - # Capture the callback_adj for the reverse pass and remove both callbacks - kwargs_adj = NamedTuple{Base.diff_names(Base._nt_names(values(kwargs)), (:callback_adj,:callback))}(values(kwargs)) - isq = sensealg isa QuadratureAdjoint - if typeof(sensealg) <: BacksolveAdjoint - sol = solve(_prob,alg,args...;save_noise=true, - save_start=save_start,save_end=save_end, - saveat=saveat,kwargs_fwd...) - elseif ischeckpointing(sensealg) - sol = solve(_prob,alg,args...;save_noise=true, - save_start=true,save_end=true, - saveat=saveat,kwargs_fwd...) - else - sol = solve(_prob,alg,args...;save_noise=true,save_start=true, - save_end=true,kwargs_fwd...) - end - - # Force `save_start` and `save_end` in the forward pass This forces the - # solver to do the backsolve all the way back to `u0` Since the start aliases - # `_prob.u0`, this doesn't actually use more memory But it cleans up the - # implementation and makes `save_start` and `save_end` arg safe. - if typeof(sensealg) <: BacksolveAdjoint - # Saving behavior unchanged - ts = sol.t - only_end = length(ts) == 1 && ts[1] == _prob.tspan[2] - out = DiffEqBase.sensitivity_solution(sol,sol.u,ts) - elseif saveat isa Number - if _prob.tspan[2] > _prob.tspan[1] - ts = _prob.tspan[1]:convert(typeof(_prob.tspan[2]),abs(saveat)):_prob.tspan[2] - else - ts = _prob.tspan[2]:convert(typeof(_prob.tspan[2]),abs(saveat)):_prob.tspan[1] - end - # if _prob.tspan[2]-_prob.tspan[1] is not a multiple of saveat, one looses the last ts value - sol.t[end] !== ts[end] && (ts = fix_endpoints(sensealg,sol,ts)) - if cb === nothing - _out = sol(ts) - else - _, duplicate_iterator_times = separate_nonunique(sol.t) - _out, ts = out_and_ts(ts, duplicate_iterator_times, sol) - end - - out = if save_idxs === nothing - out = DiffEqBase.sensitivity_solution(sol,_out.u,ts) - else - out = DiffEqBase.sensitivity_solution(sol,[_out[i][save_idxs] for i in 1:length(_out)],ts) - end - only_end = length(ts) == 1 && ts[1] == _prob.tspan[2] - elseif isempty(saveat) - no_start = !save_start - no_end = !save_end - sol_idxs = 1:length(sol) - no_start && (sol_idxs = sol_idxs[2:end]) - no_end && (sol_idxs = sol_idxs[1:end-1]) - only_end = length(sol_idxs) <= 1 - _u = sol.u[sol_idxs] - u = save_idxs === nothing ? _u : [x[save_idxs] for x in _u] - ts = sol.t[sol_idxs] - out = DiffEqBase.sensitivity_solution(sol,u,ts) - else - _saveat = saveat isa Array ? sort(saveat) : saveat # for minibatching - if cb === nothing - _saveat = eltype(_saveat) <: typeof(prob.tspan[2]) ? convert.(typeof(_prob.tspan[2]),_saveat) : _saveat - ts = _saveat - _out = sol(ts) - else - _ts, duplicate_iterator_times = separate_nonunique(sol.t) - _out, ts = out_and_ts(_saveat, duplicate_iterator_times, sol) - end - - out = if save_idxs === nothing - out = DiffEqBase.sensitivity_solution(sol,_out.u,ts) - else - out = DiffEqBase.sensitivity_solution(sol,[_out[i][save_idxs] for i in 1:length(_out)],ts) - end - only_end = length(ts) == 1 && ts[1] == _prob.tspan[2] - end - - _save_idxs = save_idxs === nothing ? Colon() : save_idxs - - function adjoint_sensitivity_backpass(Δ) - function df(_out, u, p, t, i) - outtype = typeof(_out) <: SubArray ? DiffEqBase.parameterless_type(_out.parent) : DiffEqBase.parameterless_type(_out) - if only_end - eltype(Δ) <: NoTangent && return - if typeof(Δ) <: AbstractArray{<:AbstractArray} && length(Δ) == 1 && i == 1 - # user did sol[end] on only_end - if typeof(_save_idxs) <: Number - x = vec(Δ[1]) - _out[_save_idxs] .= adapt(outtype,@view(x[_save_idxs])) - elseif _save_idxs isa Colon - vec(_out) .= adapt(outtype,vec(Δ[1])) - else - vec(@view(_out[_save_idxs])) .= adapt(outtype,vec(Δ[1])[_save_idxs]) - end - else - Δ isa NoTangent && return - if typeof(_save_idxs) <: Number - x = vec(Δ) - _out[_save_idxs] .= adapt(outtype,@view(x[_save_idxs])) - elseif _save_idxs isa Colon - vec(_out) .= adapt(outtype,vec(Δ)) - else - x = vec(Δ) - vec(@view(_out[_save_idxs])) .= adapt(outtype,@view(x[_save_idxs])) - end - end - else - !Base.isconcretetype(eltype(Δ)) && (Δ[i] isa NoTangent || eltype(Δ) <: NoTangent) && return - if typeof(Δ) <: AbstractArray{<:AbstractArray} || typeof(Δ) <: DESolution - x = Δ[i] - if typeof(_save_idxs) <: Number - _out[_save_idxs] = @view(x[_save_idxs]) - elseif _save_idxs isa Colon - vec(_out) .= vec(x) - else - vec(@view(_out[_save_idxs])) .= vec(@view(x[_save_idxs])) - end - else - if typeof(_save_idxs) <: Number - _out[_save_idxs] = adapt(outtype,reshape(Δ, prod(size(Δ)[1:end-1]), size(Δ)[end])[_save_idxs, i]) - elseif _save_idxs isa Colon - vec(_out) .= vec(adapt(outtype,reshape(Δ, prod(size(Δ)[1:end-1]), size(Δ)[end])[:, i])) - else - vec(@view(_out[_save_idxs])) .= vec(adapt(outtype,reshape(Δ, prod(size(Δ)[1:end-1]), size(Δ)[end])[:, i])) - end - end - end - end - - if haskey(kwargs_adj, :callback_adj) - cb2 = CallbackSet(cb,kwargs[:callback_adj]) - else - cb2 = cb - end - - du0, dp = adjoint_sensitivities(sol,alg,args...; t=ts, dg_discrete=df, sensealg=sensealg, - callback = cb2, - kwargs_adj...) - - du0 = reshape(du0,size(u0)) - dp = p === nothing || p === DiffEqBase.NullParameters() ? nothing : reshape(dp',size(p)) - - if originator isa SciMLBase.TrackerOriginator || originator isa SciMLBase.ReverseDiffOriginator - (NoTangent(),NoTangent(),du0,dp,NoTangent(),ntuple(_->NoTangent(), length(args))...) - else - (NoTangent(),NoTangent(),NoTangent(),du0,dp,NoTangent(),ntuple(_->NoTangent(), length(args))...) - end - end - out, adjoint_sensitivity_backpass -end - -# Prefer this route since it works better with callback AD -function DiffEqBase._concrete_solve_adjoint(prob, alg, sensealg::AbstractForwardSensitivityAlgorithm, - u0, p, originator::SciMLBase.ADOriginator, args...; - save_idxs=nothing, - kwargs...) - - if !(typeof(p) <: Union{Nothing,SciMLBase.NullParameters,AbstractArray}) || (p isa AbstractArray && !Base.isconcretetype(eltype(p))) - throw(ForwardSensitivityParameterCompatibilityError()) - end - - if p isa AbstractArray && eltype(p) <: ForwardDiff.Dual && !(eltype(u0) <: ForwardDiff.Dual) - # Handle double differentiation case - u0 = eltype(p).(u0) - end - _prob = ODEForwardSensitivityProblem(prob.f, u0, prob.tspan, p, sensealg) - sol = solve(_prob, alg, args...; kwargs...) - _, du = extract_local_sensitivities(sol, sensealg, Val(true)) - - u = if save_idxs === nothing - [reshape(sol[i][1:length(u0)], size(u0)) for i in 1:length(sol)] - else - [sol[i][_save_idxs] for i in 1:length(sol)] - end - out = DiffEqBase.sensitivity_solution(sol, u, sol.t) - - function forward_sensitivity_backpass(Δ) - adj = sum(eachindex(du)) do i - J = du[i] - if Δ isa AbstractVector || Δ isa DESolution || Δ isa AbstractVectorOfArray - v = Δ[i] - elseif Δ isa AbstractMatrix - v = @view Δ[:, i] - else - v = @view Δ[.., i] - end - J'vec(v) - end - - du0 = @not_implemented( - "ForwardSensitivity does not differentiate with respect to u0. Change your sensealg." - ) - - if originator isa SciMLBase.TrackerOriginator || originator isa SciMLBase.ReverseDiffOriginator - (NoTangent(), NoTangent(), du0, adj, NoTangent(), ntuple(_ -> NoTangent(), length(args))...) - else - (NoTangent(), NoTangent(), NoTangent(), du0, adj, NoTangent(), ntuple(_ -> NoTangent(), length(args))...) - end - end - out, forward_sensitivity_backpass -end - -function DiffEqBase._concrete_solve_forward(prob,alg, - sensealg::AbstractForwardSensitivityAlgorithm, - u0,p,originator::SciMLBase.ADOriginator,args...;save_idxs = nothing, - kwargs...) - _prob = ODEForwardSensitivityProblem(prob.f,u0,prob.tspan,p,sensealg) - sol = solve(_prob,args...;kwargs...) - u,du = extract_local_sensitivities(sol,Val(true)) - _save_idxs = save_idxs === nothing ? (1:length(u0)) : save_idxs - out = DiffEqBase.sensitivity_solution(sol,[ForwardDiff.value.(sol[i][_save_idxs]) for i in 1:length(sol)],sol.t) - function _concrete_solve_pushforward(Δself, ::Nothing, ::Nothing, x3, Δp, args...) - x3 !== nothing && error("Pushforward currently requires no u0 derivatives") - du * Δp - end - out,_concrete_solve_pushforward -end - -const FORWARDDIFF_SENSITIVITY_PARAMETER_COMPATABILITY_MESSAGE = -""" -ForwardDiffSensitivity assumes the `AbstractArray` interface for `p`. Thus while -DifferentialEquations.jl can support any parameter struct type, usage -with ForwardDiffSensitivity requires that `p` could be a valid -type for being the initial condition `u0` of an array. This means that -many simple types, such as `Tuple`s and `NamedTuple`s, will work as -parameters in normal contexts but will fail during ForwardDiffSensitivity -construction. To work around this issue for complicated cases like nested structs, -look into defining `p` using `AbstractArray` libraries such as RecursiveArrayTools.jl -or ComponentArrays.jl. -""" - -struct ForwardDiffSensitivityParameterCompatibilityError <: Exception end - -function Base.showerror(io::IO, e::ForwardDiffSensitivityParameterCompatibilityError) - print(io, FORWARDDIFF_SENSITIVITY_PARAMETER_COMPATABILITY_MESSAGE) -end - -# Generic Fallback for ForwardDiff -function DiffEqBase._concrete_solve_adjoint(prob,alg, - sensealg::ForwardDiffSensitivity{CS,CTS}, - u0,p,originator::SciMLBase.ADOriginator,args...;saveat=eltype(prob.tspan)[], - kwargs...) where {CS,CTS} - - if !(typeof(p) <: Union{Nothing,SciMLBase.NullParameters,AbstractArray}) || (p isa AbstractArray && !Base.isconcretetype(eltype(p))) - throw(ForwardDiffSensitivityParameterCompatibilityError()) - end - - if saveat isa Number - _saveat = prob.tspan[1]:saveat:prob.tspan[2] - else - _saveat = saveat - end - - sol = solve(remake(prob,p=p,u0=u0),alg,args...;saveat=_saveat, kwargs...) - - # saveat values - # seems overcomplicated, but see the PR - if length(sol.t) == 1 - ts = sol.t - else - ts = eltype(sol.t)[] - if sol.t[2] != sol.t[1] - push!(ts,sol.t[1]) - end - for i in 2:length(sol.t)-1 - if sol.t[i] != sol.t[i+1] && sol.t[i] != sol.t[i-1] - push!(ts,sol.t[i]) - end - end - if sol.t[end] != sol.t[end-1] - push!(ts,sol.t[end]) - end - end - - function forward_sensitivity_backpass(Δ) - dp = @thunk begin - - chunk_size = if CS === 0 && length(p) < 12 - length(p) - elseif CS !== 0 - CS - else - 12 - end - - num_chunks = length(p) ÷ chunk_size - num_chunks * chunk_size != length(p) && (num_chunks += 1) - - pparts = typeof(p[1:1])[] - for j in 0:(num_chunks-1) - - local chunk - if ((j+1)*chunk_size) <= length(p) - chunk = ((j*chunk_size+1) : ((j+1)*chunk_size)) - pchunk = vec(p)[chunk] - pdualpart = seed_duals(pchunk,prob.f,ForwardDiff.Chunk{chunk_size}()) - else - chunk = ((j*chunk_size+1) : length(p)) - pchunk = vec(p)[chunk] - pdualpart = seed_duals(pchunk,prob.f,ForwardDiff.Chunk{length(chunk)}()) - end - - pdualvec = if j == 0 - vcat(pdualpart,p[(j+1)*chunk_size+1 : end]) - elseif j == num_chunks-1 - vcat(p[1:j*chunk_size],pdualpart) - else - vcat(p[1:j*chunk_size],pdualpart,p[((j+1)*chunk_size)+1 : end]) - end - - pdual = ArrayInterfaceCore.restructure(p,pdualvec) - u0dual = convert.(eltype(pdualvec),u0) - - if (convert_tspan(sensealg) === nothing && ( - (haskey(kwargs,:callback) && has_continuous_callback(kwargs[:callback])))) || - (convert_tspan(sensealg) !== nothing && convert_tspan(sensealg)) - - tspandual = convert.(eltype(pdual),prob.tspan) - else - tspandual = prob.tspan - end - - if typeof(prob.f) <: ODEFunction && prob.f.jac_prototype !== nothing - _f = ODEFunction{SciMLBase.isinplace(prob.f),true}(prob.f,jac_prototype = convert.(eltype(u0dual),prob.f.jac_prototype)) - elseif typeof(prob.f) <: SDEFunction && prob.f.jac_prototype !== nothing - _f = SDEFunction{SciMLBase.isinplace(prob.f),true}(prob.f,jac_prototype = convert.(eltype(u0dual),prob.f.jac_prototype)) - else - _f = prob.f - end - _prob = remake(prob,f=_f,u0=u0dual,p=pdual,tspan=tspandual) - - if _prob isa SDEProblem - _prob.noise_rate_prototype!==nothing && (_prob = remake(_prob, noise_rate_prototype = convert.(eltype(pdual), _prob.noise_rate_prototype))) - end - - if saveat isa Number - _saveat = prob.tspan[1]:saveat:prob.tspan[2] - else - _saveat = saveat - end - - _sol = solve(_prob,alg,args...;saveat=ts, kwargs...) - _,du = extract_local_sensitivities(_sol, sensealg, Val(true)) - - _dp = sum(eachindex(du)) do i - J = du[i] - if Δ isa AbstractVector || Δ isa DESolution || Δ isa AbstractVectorOfArray - v = Δ[i] - elseif Δ isa AbstractMatrix - v = @view Δ[:, i] - else - v = @view Δ[.., i] - end - if !(Δ isa NoTangent) - ForwardDiff.value.(J'vec(v)) - else - zero(p) - end - end - push!(pparts,vec(_dp)) - end - ArrayInterfaceCore.restructure(p,reduce(vcat,pparts)) - end - - du0 = @thunk begin - - chunk_size = if CS === 0 && length(u0) < 12 - length(u0) - elseif CS !== 0 - CS - else - 12 - end - - num_chunks = length(u0) ÷ chunk_size - num_chunks * chunk_size != length(u0) && (num_chunks += 1) - - du0parts = typeof(u0[1:1])[] - for j in 0:(num_chunks-1) - - local chunk - if ((j+1)*chunk_size) <= length(u0) - chunk = ((j*chunk_size+1) : ((j+1)*chunk_size)) - u0chunk = vec(u0)[chunk] - u0dualpart = seed_duals(u0chunk,prob.f,ForwardDiff.Chunk{chunk_size}()) - else - chunk = ((j*chunk_size+1) : length(u0)) - u0chunk = vec(u0)[chunk] - u0dualpart = seed_duals(u0chunk,prob.f,ForwardDiff.Chunk{length(chunk)}()) - end - - u0dualvec = if j == 0 - vcat(u0dualpart,u0[(j+1)*chunk_size+1 : end]) - elseif j == num_chunks-1 - vcat(u0[1:j*chunk_size],u0dualpart) - else - vcat(u0[1:j*chunk_size],u0dualpart,u0[((j+1)*chunk_size)+1 : end]) - end - - u0dual = ArrayInterfaceCore.restructure(u0,u0dualvec) - pdual = convert.(eltype(u0dual),p) - - if (convert_tspan(sensealg) === nothing && ( - (haskey(kwargs,:callback) && has_continuous_callback(kwargs[:callback])))) || - (convert_tspan(sensealg) !== nothing && convert_tspan(sensealg)) - - tspandual = convert.(eltype(pdual),prob.tspan) - else - tspandual = prob.tspan - end - - if typeof(prob.f) <: ODEFunction && prob.f.jac_prototype !== nothing - _f = ODEFunction{SciMLBase.isinplace(prob.f),true}(prob.f,jac_prototype = convert.(eltype(pdual),prob.f.jac_prototype)) - elseif typeof(prob.f) <: SDEFunction && prob.f.jac_prototype !== nothing - _f = SDEFunction{SciMLBase.isinplace(prob.f),true}(prob.f,jac_prototype = convert.(eltype(pdual),prob.f.jac_prototype)) - else - _f = prob.f - end - _prob = remake(prob,f=_f,u0=u0dual,p=pdual,tspan=tspandual) - - if _prob isa SDEProblem - _prob.noise_rate_prototype!==nothing && (_prob = remake(_prob, noise_rate_prototype = convert.(eltype(pdual), _prob.noise_rate_prototype))) - end - - if saveat isa Number - _saveat = prob.tspan[1]:saveat:prob.tspan[2] - else - _saveat = saveat - end - - _sol = solve(_prob,alg,args...;saveat=ts, kwargs...) - _,du = extract_local_sensitivities(_sol, sensealg, Val(true)) - - _du0 = sum(eachindex(du)) do i - J = du[i] - if Δ isa AbstractVector || Δ isa DESolution || Δ isa AbstractVectorOfArray - v = Δ[i] - elseif Δ isa AbstractMatrix - v = @view Δ[:, i] - else - v = @view Δ[.., i] - end - if !(Δ isa NoTangent) - ForwardDiff.value.(J'vec(v)) - else - zero(u0) - end - end - push!(du0parts,vec(_du0)) - end - ArrayInterfaceCore.restructure(u0,reduce(vcat,du0parts)) - end - - if originator isa SciMLBase.TrackerOriginator || originator isa SciMLBase.ReverseDiffOriginator - (NoTangent(),NoTangent(),unthunk(du0),unthunk(dp),NoTangent(),ntuple(_->NoTangent(), length(args))...) - else - (NoTangent(),NoTangent(),NoTangent(),du0,dp,NoTangent(),ntuple(_->NoTangent(), length(args))...) - end - end - sol,forward_sensitivity_backpass -end - -function DiffEqBase._concrete_solve_adjoint(prob,alg,sensealg::ZygoteAdjoint, - u0,p,originator::SciMLBase.ADOriginator,args...;kwargs...) - Zygote.pullback((u0,p)->solve(prob,alg,args...;u0=u0,p=p, - sensealg = SensitivityADPassThrough(),kwargs...),u0,p) -end - -function DiffEqBase._concrete_solve_adjoint(prob,alg,sensealg::TrackerAdjoint, - u0,p,originator::SciMLBase.ADOriginator,args...; - kwargs...) - - local sol - function tracker_adjoint_forwardpass(_u0,_p) - - if (convert_tspan(sensealg) === nothing && ( - (haskey(kwargs,:callback) && has_continuous_callback(kwargs[:callback])))) || - (convert_tspan(sensealg) !== nothing && convert_tspan(sensealg)) - _tspan = convert.(eltype(_p),prob.tspan) - else - _tspan = prob.tspan - end - - if DiffEqBase.isinplace(prob) - # use Array{TrackedReal} for mutation to work - # Recurse to all Array{TrackedArray} - _prob = remake(prob,u0=map(identity,_u0),p=_p,tspan=_tspan) - else - # use TrackedArray for efficiency of the tape - if typeof(prob) <: Union{SciMLBase.AbstractDDEProblem,SciMLBase.AbstractDAEProblem,SciMLBase.AbstractSDDEProblem} - _f = function (u,p,h,t) # For DDE, but also works for (du,u,p,t) DAE - out = prob.f(u,p,h,t) - if out isa TrackedArray - return out - else - Tracker.collect(out) - end - end - - # Only define `g` for the stochastic ones - if typeof(prob) <: SciMLBase.AbstractSDEProblem - _g = function (u,p,h,t) - out = prob.g(u,p,h,t) - if out isa TrackedArray - return out - else - Tracker.collect(out) - end - end - _prob = remake(prob,f=DiffEqBase.parameterless_type(prob.f){false,true}(_f,_g),u0=_u0,p=_p,tspan=_tspan) - else - _prob = remake(prob,f=DiffEqBase.parameterless_type(prob.f){false,true}(_f),u0=_u0,p=_p,tspan=_tspan) - end - elseif typeof(prob) <: Union{SciMLBase.AbstractODEProblem,SciMLBase.AbstractSDEProblem} - _f = function (u,p,t) - out = prob.f(u,p,t) - if out isa TrackedArray - return out - else - Tracker.collect(out) - end - end - if typeof(prob) <: SciMLBase.AbstractSDEProblem - _g = function (u,p,t) - out = prob.g(u,p,t) - if out isa TrackedArray - return out - else - Tracker.collect(out) - end - end - _prob = remake(prob,f=DiffEqBase.parameterless_type(prob.f){false,true}(_f,_g),u0=_u0,p=_p,tspan=_tspan) - else - _prob = remake(prob,f=DiffEqBase.parameterless_type(prob.f){false,true}(_f),u0=_u0,p=_p,tspan=_tspan) - end - else - error("TrackerAdjont does not currently support the specified problem type. Please open an issue.") - end - end - sol = solve(_prob,alg,args...;sensealg=DiffEqBase.SensitivityADPassThrough(),kwargs...) - - if typeof(sol.u[1]) <: Array - return Array(sol) - else - tmp = vec(sol.u[1]) - for i in 2:length(sol.u) - tmp = hcat(tmp,vec(sol.u[i])) - end - return reshape(tmp,size(sol.u[1])...,length(sol.u)) - end - #adapt(typeof(u0),arr) - sol - end - - out,pullback = Tracker.forward(tracker_adjoint_forwardpass,u0,p) - function tracker_adjoint_backpass(ybar) - tmp = if eltype(ybar) <: Number && typeof(u0) <: Array - Array(ybar) - elseif eltype(ybar) <: Number # CuArray{Floats} - ybar - elseif typeof(ybar[1]) <: Array - return Array(ybar) - else - tmp = vec(ybar.u[1]) - for i in 2:length(ybar.u) - tmp = hcat(tmp,vec(ybar.u[i])) - end - return reshape(tmp,size(ybar.u[1])...,length(ybar.u)) - end - u0bar, pbar = pullback(tmp) - _u0bar = u0bar isa Tracker.TrackedArray ? Tracker.data(u0bar) : Tracker.data.(u0bar) - - if originator isa SciMLBase.TrackerOriginator || originator isa SciMLBase.ReverseDiffOriginator - (NoTangent(),NoTangent(),_u0bar,Tracker.data(pbar),NoTangent(),ntuple(_->NoTangent(), length(args))...) - else - (NoTangent(),NoTangent(),NoTangent(),_u0bar,Tracker.data(pbar),NoTangent(),ntuple(_->NoTangent(), length(args))...) - end - end - - u = u0 isa Tracker.TrackedArray ? Tracker.data.(sol.u) : Tracker.data.(Tracker.data.(sol.u)) - DiffEqBase.sensitivity_solution(sol,u,Tracker.data.(sol.t)),tracker_adjoint_backpass -end - -const REVERSEDIFF_ADJOINT_GPU_COMPATABILITY_MESSAGE = -""" -ReverseDiffAdjoint is not compatible GPU-based array types. Use a different -sensitivity analysis method, like InterpolatingAdjoint or TrackerAdjoint, -in order to combine with GPUs. -""" - -struct ReverseDiffGPUStateCompatibilityError <: Exception end - -function Base.showerror(io::IO, e::ReverseDiffGPUStateCompatibilityError) - print(io, FORWARDDIFF_SENSITIVITY_PARAMETER_COMPATABILITY_MESSAGE) -end - -function DiffEqBase._concrete_solve_adjoint(prob,alg,sensealg::ReverseDiffAdjoint, - u0,p,originator::SciMLBase.ADOriginator,args...;kwargs...) - - if typeof(u0) isa GPUArraysCore.AbstractGPUArray - throw(ReverseDiffGPUStateCompatibilityError()) - end - - t = eltype(prob.tspan)[] - u = typeof(u0)[] - - local sol - - function reversediff_adjoint_forwardpass(_u0,_p) - - if (convert_tspan(sensealg) === nothing && ( - (haskey(kwargs,:callback) && has_continuous_callback(kwargs[:callback])))) || - (convert_tspan(sensealg) !== nothing && convert_tspan(sensealg)) - _tspan = convert.(eltype(_p),prob.tspan) - else - _tspan = prob.tspan - end - - if DiffEqBase.isinplace(prob) - # use Array{TrackedReal} for mutation to work - # Recurse to all Array{TrackedArray} - _prob = remake(prob,u0=reshape([x for x in _u0],size(_u0)),p=_p,tspan=_tspan) - else - # use TrackedArray for efficiency of the tape - _f(args...) = reduce(vcat,prob.f(args...)) - if prob isa SDEProblem - _g(args...) = reduce(vcat,prob.g(args...)) - _prob = remake(prob,f=DiffEqBase.parameterless_type(prob.f){SciMLBase.isinplace(prob),true}(_f,_g),u0=_u0,p=_p,tspan=_tspan) - else - _prob = remake(prob,f=DiffEqBase.parameterless_type(prob.f){SciMLBase.isinplace(prob),true}(_f),u0=_u0,p=_p,tspan=_tspan) - end - end - - sol = solve(_prob,alg,args...;sensealg=DiffEqBase.SensitivityADPassThrough(),kwargs...) - t = sol.t - if DiffEqBase.isinplace(prob) - u = map.(ReverseDiff.value,sol.u) - else - u = map(ReverseDiff.value,sol.u) - end - Array(sol) - end - - tape = ReverseDiff.GradientTape(reversediff_adjoint_forwardpass,(u0, p)) - tu, tp = ReverseDiff.input_hook(tape) - output = ReverseDiff.output_hook(tape) - ReverseDiff.value!(tu, u0) - typeof(p) <: DiffEqBase.NullParameters || ReverseDiff.value!(tp, p) - ReverseDiff.forward_pass!(tape) - function reversediff_adjoint_backpass(ybar) - _ybar = if ybar isa VectorOfArray - Array(ybar) - elseif eltype(ybar) <: AbstractArray - Array(VectorOfArray(ybar)) - else - ybar - end - ReverseDiff.increment_deriv!(output, _ybar) - ReverseDiff.reverse_pass!(tape) - - if originator isa SciMLBase.TrackerOriginator || originator isa SciMLBase.ReverseDiffOriginator - (NoTangent(),NoTangent(),ReverseDiff.deriv(tu),ReverseDiff.deriv(tp),NoTangent(),ntuple(_->NoTangent(), length(args))...) - else - (NoTangent(),NoTangent(),NoTangent(),ReverseDiff.deriv(tu),ReverseDiff.deriv(tp),NoTangent(),ntuple(_->NoTangent(), length(args))...) - end - end - Array(VectorOfArray(u)),reversediff_adjoint_backpass -end - - -function DiffEqBase._concrete_solve_adjoint(prob,alg, - sensealg::AbstractShadowingSensitivityAlgorithm, - u0,p,originator::SciMLBase.ADOriginator,args...;save_start=true,save_end=true, - saveat = eltype(prob.tspan)[], - save_idxs = nothing, - kwargs...) - - if haskey(kwargs, :callback) - error("Sensitivity analysis based on Least Squares Shadowing is not compatible with callbacks. Please select another `sensealg`.") - else - _prob = remake(prob,u0=u0,p=p) - end - - sol = solve(_prob,alg,args...;save_start=save_start,save_end=save_end,saveat=saveat,kwargs...) - - if saveat isa Number - if _prob.tspan[2] > _prob.tspan[1] - ts = _prob.tspan[1]:convert(typeof(_prob.tspan[2]),abs(saveat)):_prob.tspan[2] - else - ts = _prob.tspan[2]:convert(typeof(_prob.tspan[2]),abs(saveat)):_prob.tspan[1] - end - _out = sol(ts) - out = if save_idxs === nothing - out = DiffEqBase.sensitivity_solution(sol,_out.u,sol.t) - else - out = DiffEqBase.sensitivity_solution(sol,[_out[i][save_idxs] for i in 1:length(_out)],ts) - end - # only_end - (length(ts) == 1 && ts[1] == _prob.tspan[2]) && error("Sensitivity analysis based on Least Squares Shadowing requires a long-time averaged quantity.") - elseif isempty(saveat) - no_start = !save_start - no_end = !save_end - sol_idxs = 1:length(sol) - no_start && (sol_idxs = sol_idxs[2:end]) - no_end && (sol_idxs = sol_idxs[1:end-1]) - only_end = length(sol_idxs) <= 1 - _u = sol.u[sol_idxs] - u = save_idxs === nothing ? _u : [x[save_idxs] for x in _u] - ts = sol.t[sol_idxs] - out = DiffEqBase.sensitivity_solution(sol,u,ts) - else - _saveat = saveat isa Array ? sort(saveat) : saveat # for minibatching - ts = _saveat - _out = sol(ts) - - out = if save_idxs === nothing - out = DiffEqBase.sensitivity_solution(sol,_out.u,ts) - else - out = DiffEqBase.sensitivity_solution(sol,[_out[i][save_idxs] for i in 1:length(_out)],ts) - end - # only_end - (length(ts) == 1 && ts[1] == _prob.tspan[2]) && error("Sensitivity analysis based on Least Squares Shadowing requires a long-time averaged quantity.") - end - - _save_idxs = save_idxs === nothing ? Colon() : save_idxs - - function adjoint_sensitivity_backpass(Δ) - function df(_out, u, p, t, i) - if typeof(Δ) <: AbstractArray{<:AbstractArray} || typeof(Δ) <: DESolution - if typeof(_save_idxs) <: Number - _out[_save_idxs] = Δ[i][_save_idxs] - elseif _save_idxs isa Colon - vec(_out) .= vec(Δ[i]) - else - vec(@view(_out[_save_idxs])) .= vec(Δ[i][_save_idxs]) - end - else - if typeof(_save_idxs) <: Number - _out[_save_idxs] = adapt(DiffEqBase.parameterless_type(u0),reshape(Δ, prod(size(Δ)[1:end-1]), size(Δ)[end])[_save_idxs, i]) - elseif _save_idxs isa Colon - vec(_out) .= vec(adapt(DiffEqBase.parameterless_type(u0),reshape(Δ, prod(size(Δ)[1:end-1]), size(Δ)[end])[:, i])) - else - vec(@view(_out[_save_idxs])) .= vec(adapt(DiffEqBase.parameterless_type(u0),reshape(Δ, prod(size(Δ)[1:end-1]), size(Δ)[end])[:, i])) - end - end - end - - if sensealg isa ForwardLSS - lss_problem = ForwardLSSProblem(sol, sensealg, t=ts, dg_discrete=df) - dp = shadow_forward(lss_problem) - elseif sensealg isa AdjointLSS - adjointlss_problem = AdjointLSSProblem(sol, sensealg, t=ts, dg_discrete=df) - dp = shadow_adjoint(adjointlss_problem) - elseif sensealg isa NILSS - nilss_prob = NILSSProblem(_prob, sensealg, t=ts, dg_discrete=df) - dp = shadow_forward(nilss_prob,alg) - elseif sensealg isa NILSAS - nilsas_prob = NILSASProblem(_prob, sensealg, t=ts, dg_discrete=df) - dp = shadow_adjoint(nilsas_prob,alg) - else - error("No concrete_solve implementation found for sensealg `$sensealg`. Did you spell the sensitivity algorithm correctly? Please report this error.") - end - - if originator isa SciMLBase.TrackerOriginator || originator isa SciMLBase.ReverseDiffOriginator - (NoTangent(),NoTangent(),NoTangent(),dp,NoTangent(),ntuple(_->NoTangent(), length(args))...) - else - (NoTangent(),NoTangent(),NoTangent(),NoTangent(),dp,NoTangent(),ntuple(_->NoTangent(), length(args))...) - end - end - out, adjoint_sensitivity_backpass -end - -function DiffEqBase._concrete_solve_adjoint(prob::Union{NonlinearProblem,SteadyStateProblem}, - alg,sensealg::SteadyStateAdjoint, - u0,p,originator::SciMLBase.ADOriginator,args...;save_idxs = nothing, kwargs...) - - _prob = remake(prob,u0=u0,p=p) - sol = solve(_prob,alg,args...;kwargs...) - _save_idxs = save_idxs === nothing ? Colon() : save_idxs - - if save_idxs === nothing - out = sol - else - out = DiffEqBase.sensitivity_solution(sol,sol[_save_idxs]) - end - - function steadystatebackpass(Δ) - # Δ = dg/dx or diffcache.dg_val - # del g/del p = 0 - dp = adjoint_sensitivities(sol,alg;sensealg=sensealg,g=nothing,dg=Δ,save_idxs=save_idxs) - - if originator isa SciMLBase.TrackerOriginator || originator isa SciMLBase.ReverseDiffOriginator - (NoTangent(),NoTangent(),NoTangent(),dp,NoTangent(),ntuple(_->NoTangent(), length(args))...) - else - (NoTangent(),NoTangent(),NoTangent(),NoTangent(),dp,NoTangent(),ntuple(_->NoTangent(), length(args))...) - end - end - out, steadystatebackpass -end - -function fix_endpoints(sensealg,sol,ts) - @warn "Endpoints do not match. Return code: $(sol.retcode). Likely your time range is not a multiple of `saveat`. sol.t[end]: $(sol.t[end]), ts[end]: $(ts[end])" - ts = collect(ts) - push!(ts, sol.t[end]) -end +## High level + +# Here is where we can add a default algorithm for computing sensitivities +# Based on problem information! + +function inplace_vjp(prob, u0, p, verbose) + du = copy(u0) + ez = try + Enzyme.autodiff(Enzyme.Duplicated(du, du), + copy(u0), copy(p), prob.tspan[1]) do out, u, _p, t + prob.f(out, u, _p, t) + nothing + end + true + catch + false + end + if ez + return EnzymeVJP() + end + + # Determine if we can compile ReverseDiff + compile = try + if DiffEqBase.isinplace(prob) + !hasbranching(prob.f, copy(u0), u0, p, prob.tspan[1]) + else + !hasbranching(prob.f, u0, p, prob.tspan[1]) + end + catch + false + end + + vjp = try + ReverseDiff.GradientTape((copy(u0), p, [prob.tspan[1]])) do u, p, t + du1 = similar(u, size(u)) + prob.f(du1, u, p, first(t)) + return vec(du1) + end + ReverseDiffVJP(compile) + catch + false + end + return vjp +end + +function automatic_sensealg_choice(prob::Union{ODEProblem, SDEProblem}, u0, p, verbose) + default_sensealg = if p !== DiffEqBase.NullParameters() && + !(eltype(u0) <: ForwardDiff.Dual) && + !(eltype(p) <: ForwardDiff.Dual) && + !(eltype(u0) <: Complex) && + !(eltype(p) <: Complex) && + length(u0) + length(p) <= 100 + ForwardDiffSensitivity() + elseif u0 isa GPUArraysCore.AbstractGPUArray || !DiffEqBase.isinplace(prob) + # only Zygote is GPU compatible and fast + # so if out-of-place, try Zygote + if p === nothing || p === DiffEqBase.NullParameters() + # QuadratureAdjoint skips all p calculations until the end + # So it's the fastest when there are no parameters + QuadratureAdjoint(autojacvec = ZygoteVJP()) + else + InterpolatingAdjoint(autojacvec = ZygoteVJP()) + end + else + vjp = inplace_vjp(prob, u0, p, verbose) + if p === nothing || p === DiffEqBase.NullParameters() + QuadratureAdjoint(autojacvec = vjp) + else + InterpolatingAdjoint(autojacvec = vjp) + end + end + return default_sensealg +end + +function automatic_sensealg_choice(prob::Union{NonlinearProblem, SteadyStateProblem}, u0, p, + verbose) + default_sensealg = if u0 isa GPUArraysCore.AbstractGPUArray || + !DiffEqBase.isinplace(prob) + # autodiff = false because forwarddiff fails on many GPU kernels + # this only effects the Jacobian calculation and is same computation order + SteadyStateAdjoint(autodiff = false, autojacvec = ZygoteVJP()) + else + vjp = inplace_vjp(prob, u0, p, verbose) + SteadyStateAdjoint(autojacvec = vjp) + end + return default_sensealg +end + +function DiffEqBase._concrete_solve_adjoint(prob::Union{ODEProblem, SDEProblem}, + alg, sensealg::Nothing, u0, p, + originator::SciMLBase.ADOriginator, args...; + verbose = true, kwargs...) + if haskey(kwargs, :callback) + has_cb = kwargs[:callback] !== nothing + else + has_cb = false + end + default_sensealg = automatic_sensealg_choice(prob, u0, p, verbose) + if has_cb && typeof(default_sensealg) <: AbstractAdjointSensitivityAlgorithm + default_sensealg = setvjp(default_sensealg, ReverseDiffVJP()) + end + DiffEqBase._concrete_solve_adjoint(prob, alg, default_sensealg, u0, p, + originator::SciMLBase.ADOriginator, args...; verbose, + kwargs...) +end + +function DiffEqBase._concrete_solve_adjoint(prob::Union{NonlinearProblem, SteadyStateProblem + }, alg, + sensealg::Nothing, u0, p, + originator::SciMLBase.ADOriginator, args...; + verbose = true, kwargs...) + default_sensealg = automatic_sensealg_choice(prob, u0, p, verbose) + DiffEqBase._concrete_solve_adjoint(prob, alg, default_sensealg, u0, p, + originator::SciMLBase.ADOriginator, args...; verbose, + kwargs...) +end + +function DiffEqBase._concrete_solve_adjoint(prob::Union{DiscreteProblem, DDEProblem, + SDDEProblem, DAEProblem}, + alg, sensealg::Nothing, + u0, p, originator::SciMLBase.ADOriginator, + args...; kwargs...) + if length(u0) + length(p) > 100 + default_sensealg = ReverseDiffAdjoint() + else + default_sensealg = ForwardDiffSensitivity() + end + DiffEqBase._concrete_solve_adjoint(prob, alg, default_sensealg, u0, p, + originator::SciMLBase.ADOriginator, args...; + kwargs...) +end + +function DiffEqBase._concrete_solve_adjoint(prob, alg, + sensealg::AbstractAdjointSensitivityAlgorithm, + u0, p, originator::SciMLBase.ADOriginator, + args...; save_start = true, save_end = true, + saveat = eltype(prob.tspan)[], + save_idxs = nothing, + kwargs...) + if !(typeof(p) <: Union{Nothing, SciMLBase.NullParameters, AbstractArray}) || + (p isa AbstractArray && !Base.isconcretetype(eltype(p))) + throw(AdjointSensitivityParameterCompatibilityError()) + end + + # Remove saveat, etc. from kwargs since it's handled separately + # and letting it jump back in there can break the adjoint + kwargs_prob = NamedTuple(filter(x -> x[1] != :saveat && x[1] != :save_start && + x[1] != :save_end && x[1] != :save_idxs, + prob.kwargs)) + + if haskey(kwargs, :callback) + cb = track_callbacks(CallbackSet(kwargs[:callback]), prob.tspan[1], prob.u0, prob.p, + sensealg) + _prob = remake(prob; u0 = u0, p = p, kwargs = merge(kwargs_prob, (; callback = cb))) + else + cb = nothing + _prob = remake(prob; u0 = u0, p = p, kwargs = kwargs_prob) + end + + # Remove callbacks, saveat, etc. from kwargs since it's handled separately + kwargs_fwd = NamedTuple{Base.diff_names(Base._nt_names(values(kwargs)), (:callback,))}(values(kwargs)) + + # Capture the callback_adj for the reverse pass and remove both callbacks + kwargs_adj = NamedTuple{ + Base.diff_names(Base._nt_names(values(kwargs)), + (:callback_adj, :callback))}(values(kwargs)) + isq = sensealg isa QuadratureAdjoint + if typeof(sensealg) <: BacksolveAdjoint + sol = solve(_prob, alg, args...; save_noise = true, + save_start = save_start, save_end = save_end, + saveat = saveat, kwargs_fwd...) + elseif ischeckpointing(sensealg) + sol = solve(_prob, alg, args...; save_noise = true, + save_start = true, save_end = true, + saveat = saveat, kwargs_fwd...) + else + sol = solve(_prob, alg, args...; save_noise = true, save_start = true, + save_end = true, kwargs_fwd...) + end + + # Force `save_start` and `save_end` in the forward pass This forces the + # solver to do the backsolve all the way back to `u0` Since the start aliases + # `_prob.u0`, this doesn't actually use more memory But it cleans up the + # implementation and makes `save_start` and `save_end` arg safe. + if typeof(sensealg) <: BacksolveAdjoint + # Saving behavior unchanged + ts = sol.t + only_end = length(ts) == 1 && ts[1] == _prob.tspan[2] + out = DiffEqBase.sensitivity_solution(sol, sol.u, ts) + elseif saveat isa Number + if _prob.tspan[2] > _prob.tspan[1] + ts = _prob.tspan[1]:convert(typeof(_prob.tspan[2]), abs(saveat)):_prob.tspan[2] + else + ts = _prob.tspan[2]:convert(typeof(_prob.tspan[2]), abs(saveat)):_prob.tspan[1] + end + # if _prob.tspan[2]-_prob.tspan[1] is not a multiple of saveat, one looses the last ts value + sol.t[end] !== ts[end] && (ts = fix_endpoints(sensealg, sol, ts)) + if cb === nothing + _out = sol(ts) + else + _, duplicate_iterator_times = separate_nonunique(sol.t) + _out, ts = out_and_ts(ts, duplicate_iterator_times, sol) + end + + out = if save_idxs === nothing + out = DiffEqBase.sensitivity_solution(sol, _out.u, ts) + else + out = DiffEqBase.sensitivity_solution(sol, + [_out[i][save_idxs] + for i in 1:length(_out)], ts) + end + only_end = length(ts) == 1 && ts[1] == _prob.tspan[2] + elseif isempty(saveat) + no_start = !save_start + no_end = !save_end + sol_idxs = 1:length(sol) + no_start && (sol_idxs = sol_idxs[2:end]) + no_end && (sol_idxs = sol_idxs[1:(end - 1)]) + only_end = length(sol_idxs) <= 1 + _u = sol.u[sol_idxs] + u = save_idxs === nothing ? _u : [x[save_idxs] for x in _u] + ts = sol.t[sol_idxs] + out = DiffEqBase.sensitivity_solution(sol, u, ts) + else + _saveat = saveat isa Array ? sort(saveat) : saveat # for minibatching + if cb === nothing + _saveat = eltype(_saveat) <: typeof(prob.tspan[2]) ? + convert.(typeof(_prob.tspan[2]), _saveat) : _saveat + ts = _saveat + _out = sol(ts) + else + _ts, duplicate_iterator_times = separate_nonunique(sol.t) + _out, ts = out_and_ts(_saveat, duplicate_iterator_times, sol) + end + + out = if save_idxs === nothing + out = DiffEqBase.sensitivity_solution(sol, _out.u, ts) + else + out = DiffEqBase.sensitivity_solution(sol, + [_out[i][save_idxs] + for i in 1:length(_out)], ts) + end + only_end = length(ts) == 1 && ts[1] == _prob.tspan[2] + end + + _save_idxs = save_idxs === nothing ? Colon() : save_idxs + + function adjoint_sensitivity_backpass(Δ) + function df(_out, u, p, t, i) + outtype = typeof(_out) <: SubArray ? + DiffEqBase.parameterless_type(_out.parent) : + DiffEqBase.parameterless_type(_out) + if only_end + eltype(Δ) <: NoTangent && return + if typeof(Δ) <: AbstractArray{<:AbstractArray} && length(Δ) == 1 && i == 1 + # user did sol[end] on only_end + if typeof(_save_idxs) <: Number + x = vec(Δ[1]) + _out[_save_idxs] .= adapt(outtype, @view(x[_save_idxs])) + elseif _save_idxs isa Colon + vec(_out) .= adapt(outtype, vec(Δ[1])) + else + vec(@view(_out[_save_idxs])) .= adapt(outtype, + vec(Δ[1])[_save_idxs]) + end + else + Δ isa NoTangent && return + if typeof(_save_idxs) <: Number + x = vec(Δ) + _out[_save_idxs] .= adapt(outtype, @view(x[_save_idxs])) + elseif _save_idxs isa Colon + vec(_out) .= adapt(outtype, vec(Δ)) + else + x = vec(Δ) + vec(@view(_out[_save_idxs])) .= adapt(outtype, @view(x[_save_idxs])) + end + end + else + !Base.isconcretetype(eltype(Δ)) && + (Δ[i] isa NoTangent || eltype(Δ) <: NoTangent) && return + if typeof(Δ) <: AbstractArray{<:AbstractArray} || typeof(Δ) <: DESolution + x = Δ[i] + if typeof(_save_idxs) <: Number + _out[_save_idxs] = @view(x[_save_idxs]) + elseif _save_idxs isa Colon + vec(_out) .= vec(x) + else + vec(@view(_out[_save_idxs])) .= vec(@view(x[_save_idxs])) + end + else + if typeof(_save_idxs) <: Number + _out[_save_idxs] = adapt(outtype, + reshape(Δ, prod(size(Δ)[1:(end - 1)]), + size(Δ)[end])[_save_idxs, i]) + elseif _save_idxs isa Colon + vec(_out) .= vec(adapt(outtype, + reshape(Δ, prod(size(Δ)[1:(end - 1)]), + size(Δ)[end])[:, i])) + else + vec(@view(_out[_save_idxs])) .= vec(adapt(outtype, + reshape(Δ, + prod(size(Δ)[1:(end - 1)]), + size(Δ)[end])[:, + i])) + end + end + end + end + + if haskey(kwargs_adj, :callback_adj) + cb2 = CallbackSet(cb, kwargs[:callback_adj]) + else + cb2 = cb + end + + du0, dp = adjoint_sensitivities(sol, alg, args...; t = ts, dg_discrete = df, + sensealg = sensealg, + callback = cb2, + kwargs_adj...) + + du0 = reshape(du0, size(u0)) + dp = p === nothing || p === DiffEqBase.NullParameters() ? nothing : + reshape(dp', size(p)) + + if originator isa SciMLBase.TrackerOriginator || + originator isa SciMLBase.ReverseDiffOriginator + (NoTangent(), NoTangent(), du0, dp, NoTangent(), + ntuple(_ -> NoTangent(), length(args))...) + else + (NoTangent(), NoTangent(), NoTangent(), du0, dp, NoTangent(), + ntuple(_ -> NoTangent(), length(args))...) + end + end + out, adjoint_sensitivity_backpass +end + +# Prefer this route since it works better with callback AD +function DiffEqBase._concrete_solve_adjoint(prob, alg, + sensealg::AbstractForwardSensitivityAlgorithm, + u0, p, originator::SciMLBase.ADOriginator, + args...; + save_idxs = nothing, + kwargs...) + if !(typeof(p) <: Union{Nothing, SciMLBase.NullParameters, AbstractArray}) || + (p isa AbstractArray && !Base.isconcretetype(eltype(p))) + throw(ForwardSensitivityParameterCompatibilityError()) + end + + if p isa AbstractArray && eltype(p) <: ForwardDiff.Dual && + !(eltype(u0) <: ForwardDiff.Dual) + # Handle double differentiation case + u0 = eltype(p).(u0) + end + _prob = ODEForwardSensitivityProblem(prob.f, u0, prob.tspan, p, sensealg) + sol = solve(_prob, alg, args...; kwargs...) + _, du = extract_local_sensitivities(sol, sensealg, Val(true)) + + u = if save_idxs === nothing + [reshape(sol[i][1:length(u0)], size(u0)) for i in 1:length(sol)] + else + [sol[i][_save_idxs] for i in 1:length(sol)] + end + out = DiffEqBase.sensitivity_solution(sol, u, sol.t) + + function forward_sensitivity_backpass(Δ) + adj = sum(eachindex(du)) do i + J = du[i] + if Δ isa AbstractVector || Δ isa DESolution || Δ isa AbstractVectorOfArray + v = Δ[i] + elseif Δ isa AbstractMatrix + v = @view Δ[:, i] + else + v = @view Δ[.., i] + end + J'vec(v) + end + + du0 = @not_implemented("ForwardSensitivity does not differentiate with respect to u0. Change your sensealg.") + + if originator isa SciMLBase.TrackerOriginator || + originator isa SciMLBase.ReverseDiffOriginator + (NoTangent(), NoTangent(), du0, adj, NoTangent(), + ntuple(_ -> NoTangent(), length(args))...) + else + (NoTangent(), NoTangent(), NoTangent(), du0, adj, NoTangent(), + ntuple(_ -> NoTangent(), length(args))...) + end + end + out, forward_sensitivity_backpass +end + +function DiffEqBase._concrete_solve_forward(prob, alg, + sensealg::AbstractForwardSensitivityAlgorithm, + u0, p, originator::SciMLBase.ADOriginator, + args...; save_idxs = nothing, + kwargs...) + _prob = ODEForwardSensitivityProblem(prob.f, u0, prob.tspan, p, sensealg) + sol = solve(_prob, args...; kwargs...) + u, du = extract_local_sensitivities(sol, Val(true)) + _save_idxs = save_idxs === nothing ? (1:length(u0)) : save_idxs + out = DiffEqBase.sensitivity_solution(sol, + [ForwardDiff.value.(sol[i][_save_idxs]) + for i in 1:length(sol)], sol.t) + function _concrete_solve_pushforward(Δself, ::Nothing, ::Nothing, x3, Δp, args...) + x3 !== nothing && error("Pushforward currently requires no u0 derivatives") + du * Δp + end + out, _concrete_solve_pushforward +end + +const FORWARDDIFF_SENSITIVITY_PARAMETER_COMPATABILITY_MESSAGE = """ + ForwardDiffSensitivity assumes the `AbstractArray` interface for `p`. Thus while + DifferentialEquations.jl can support any parameter struct type, usage + with ForwardDiffSensitivity requires that `p` could be a valid + type for being the initial condition `u0` of an array. This means that + many simple types, such as `Tuple`s and `NamedTuple`s, will work as + parameters in normal contexts but will fail during ForwardDiffSensitivity + construction. To work around this issue for complicated cases like nested structs, + look into defining `p` using `AbstractArray` libraries such as RecursiveArrayTools.jl + or ComponentArrays.jl. + """ + +struct ForwardDiffSensitivityParameterCompatibilityError <: Exception end + +function Base.showerror(io::IO, e::ForwardDiffSensitivityParameterCompatibilityError) + print(io, FORWARDDIFF_SENSITIVITY_PARAMETER_COMPATABILITY_MESSAGE) +end + +# Generic Fallback for ForwardDiff +function DiffEqBase._concrete_solve_adjoint(prob, alg, + sensealg::ForwardDiffSensitivity{CS, CTS}, + u0, p, originator::SciMLBase.ADOriginator, + args...; saveat = eltype(prob.tspan)[], + kwargs...) where {CS, CTS} + if !(typeof(p) <: Union{Nothing, SciMLBase.NullParameters, AbstractArray}) || + (p isa AbstractArray && !Base.isconcretetype(eltype(p))) + throw(ForwardDiffSensitivityParameterCompatibilityError()) + end + + if saveat isa Number + _saveat = prob.tspan[1]:saveat:prob.tspan[2] + else + _saveat = saveat + end + + sol = solve(remake(prob, p = p, u0 = u0), alg, args...; saveat = _saveat, kwargs...) + + # saveat values + # seems overcomplicated, but see the PR + if length(sol.t) == 1 + ts = sol.t + else + ts = eltype(sol.t)[] + if sol.t[2] != sol.t[1] + push!(ts, sol.t[1]) + end + for i in 2:(length(sol.t) - 1) + if sol.t[i] != sol.t[i + 1] && sol.t[i] != sol.t[i - 1] + push!(ts, sol.t[i]) + end + end + if sol.t[end] != sol.t[end - 1] + push!(ts, sol.t[end]) + end + end + + function forward_sensitivity_backpass(Δ) + dp = @thunk begin + chunk_size = if CS === 0 && length(p) < 12 + length(p) + elseif CS !== 0 + CS + else + 12 + end + + num_chunks = length(p) ÷ chunk_size + num_chunks * chunk_size != length(p) && (num_chunks += 1) + + pparts = typeof(p[1:1])[] + for j in 0:(num_chunks - 1) + local chunk + if ((j + 1) * chunk_size) <= length(p) + chunk = ((j * chunk_size + 1):((j + 1) * chunk_size)) + pchunk = vec(p)[chunk] + pdualpart = seed_duals(pchunk, prob.f, ForwardDiff.Chunk{chunk_size}()) + else + chunk = ((j * chunk_size + 1):length(p)) + pchunk = vec(p)[chunk] + pdualpart = seed_duals(pchunk, prob.f, + ForwardDiff.Chunk{length(chunk)}()) + end + + pdualvec = if j == 0 + vcat(pdualpart, p[((j + 1) * chunk_size + 1):end]) + elseif j == num_chunks - 1 + vcat(p[1:(j * chunk_size)], pdualpart) + else + vcat(p[1:(j * chunk_size)], pdualpart, + p[(((j + 1) * chunk_size) + 1):end]) + end + + pdual = ArrayInterfaceCore.restructure(p, pdualvec) + u0dual = convert.(eltype(pdualvec), u0) + + if (convert_tspan(sensealg) === nothing && ((haskey(kwargs, :callback) && + has_continuous_callback(kwargs[:callback])))) || + (convert_tspan(sensealg) !== nothing && convert_tspan(sensealg)) + tspandual = convert.(eltype(pdual), prob.tspan) + else + tspandual = prob.tspan + end + + if typeof(prob.f) <: ODEFunction && prob.f.jac_prototype !== nothing + _f = ODEFunction{SciMLBase.isinplace(prob.f), true}(prob.f, + jac_prototype = convert.(eltype(u0dual), + prob.f.jac_prototype)) + elseif typeof(prob.f) <: SDEFunction && prob.f.jac_prototype !== nothing + _f = SDEFunction{SciMLBase.isinplace(prob.f), true}(prob.f, + jac_prototype = convert.(eltype(u0dual), + prob.f.jac_prototype)) + else + _f = prob.f + end + _prob = remake(prob, f = _f, u0 = u0dual, p = pdual, tspan = tspandual) + + if _prob isa SDEProblem + _prob.noise_rate_prototype !== nothing && (_prob = remake(_prob, + noise_rate_prototype = convert.(eltype(pdual), + _prob.noise_rate_prototype))) + end + + if saveat isa Number + _saveat = prob.tspan[1]:saveat:prob.tspan[2] + else + _saveat = saveat + end + + _sol = solve(_prob, alg, args...; saveat = ts, kwargs...) + _, du = extract_local_sensitivities(_sol, sensealg, Val(true)) + + _dp = sum(eachindex(du)) do i + J = du[i] + if Δ isa AbstractVector || Δ isa DESolution || + Δ isa AbstractVectorOfArray + v = Δ[i] + elseif Δ isa AbstractMatrix + v = @view Δ[:, i] + else + v = @view Δ[.., i] + end + if !(Δ isa NoTangent) + ForwardDiff.value.(J'vec(v)) + else + zero(p) + end + end + push!(pparts, vec(_dp)) + end + ArrayInterfaceCore.restructure(p, reduce(vcat, pparts)) + end + + du0 = @thunk begin + chunk_size = if CS === 0 && length(u0) < 12 + length(u0) + elseif CS !== 0 + CS + else + 12 + end + + num_chunks = length(u0) ÷ chunk_size + num_chunks * chunk_size != length(u0) && (num_chunks += 1) + + du0parts = typeof(u0[1:1])[] + for j in 0:(num_chunks - 1) + local chunk + if ((j + 1) * chunk_size) <= length(u0) + chunk = ((j * chunk_size + 1):((j + 1) * chunk_size)) + u0chunk = vec(u0)[chunk] + u0dualpart = seed_duals(u0chunk, prob.f, + ForwardDiff.Chunk{chunk_size}()) + else + chunk = ((j * chunk_size + 1):length(u0)) + u0chunk = vec(u0)[chunk] + u0dualpart = seed_duals(u0chunk, prob.f, + ForwardDiff.Chunk{length(chunk)}()) + end + + u0dualvec = if j == 0 + vcat(u0dualpart, u0[((j + 1) * chunk_size + 1):end]) + elseif j == num_chunks - 1 + vcat(u0[1:(j * chunk_size)], u0dualpart) + else + vcat(u0[1:(j * chunk_size)], u0dualpart, + u0[(((j + 1) * chunk_size) + 1):end]) + end + + u0dual = ArrayInterfaceCore.restructure(u0, u0dualvec) + pdual = convert.(eltype(u0dual), p) + + if (convert_tspan(sensealg) === nothing && ((haskey(kwargs, :callback) && + has_continuous_callback(kwargs[:callback])))) || + (convert_tspan(sensealg) !== nothing && convert_tspan(sensealg)) + tspandual = convert.(eltype(pdual), prob.tspan) + else + tspandual = prob.tspan + end + + if typeof(prob.f) <: ODEFunction && prob.f.jac_prototype !== nothing + _f = ODEFunction{SciMLBase.isinplace(prob.f), true}(prob.f, + jac_prototype = convert.(eltype(pdual), + prob.f.jac_prototype)) + elseif typeof(prob.f) <: SDEFunction && prob.f.jac_prototype !== nothing + _f = SDEFunction{SciMLBase.isinplace(prob.f), true}(prob.f, + jac_prototype = convert.(eltype(pdual), + prob.f.jac_prototype)) + else + _f = prob.f + end + _prob = remake(prob, f = _f, u0 = u0dual, p = pdual, tspan = tspandual) + + if _prob isa SDEProblem + _prob.noise_rate_prototype !== nothing && (_prob = remake(_prob, + noise_rate_prototype = convert.(eltype(pdual), + _prob.noise_rate_prototype))) + end + + if saveat isa Number + _saveat = prob.tspan[1]:saveat:prob.tspan[2] + else + _saveat = saveat + end + + _sol = solve(_prob, alg, args...; saveat = ts, kwargs...) + _, du = extract_local_sensitivities(_sol, sensealg, Val(true)) + + _du0 = sum(eachindex(du)) do i + J = du[i] + if Δ isa AbstractVector || Δ isa DESolution || + Δ isa AbstractVectorOfArray + v = Δ[i] + elseif Δ isa AbstractMatrix + v = @view Δ[:, i] + else + v = @view Δ[.., i] + end + if !(Δ isa NoTangent) + ForwardDiff.value.(J'vec(v)) + else + zero(u0) + end + end + push!(du0parts, vec(_du0)) + end + ArrayInterfaceCore.restructure(u0, reduce(vcat, du0parts)) + end + + if originator isa SciMLBase.TrackerOriginator || + originator isa SciMLBase.ReverseDiffOriginator + (NoTangent(), NoTangent(), unthunk(du0), unthunk(dp), NoTangent(), + ntuple(_ -> NoTangent(), length(args))...) + else + (NoTangent(), NoTangent(), NoTangent(), du0, dp, NoTangent(), + ntuple(_ -> NoTangent(), length(args))...) + end + end + sol, forward_sensitivity_backpass +end + +function DiffEqBase._concrete_solve_adjoint(prob, alg, sensealg::ZygoteAdjoint, + u0, p, originator::SciMLBase.ADOriginator, + args...; kwargs...) + Zygote.pullback((u0, p) -> solve(prob, alg, args...; u0 = u0, p = p, + sensealg = SensitivityADPassThrough(), kwargs...), u0, + p) +end + +function DiffEqBase._concrete_solve_adjoint(prob, alg, sensealg::TrackerAdjoint, + u0, p, originator::SciMLBase.ADOriginator, + args...; + kwargs...) + local sol + function tracker_adjoint_forwardpass(_u0, _p) + if (convert_tspan(sensealg) === nothing && + ((haskey(kwargs, :callback) && has_continuous_callback(kwargs[:callback])))) || + (convert_tspan(sensealg) !== nothing && convert_tspan(sensealg)) + _tspan = convert.(eltype(_p), prob.tspan) + else + _tspan = prob.tspan + end + + if DiffEqBase.isinplace(prob) + # use Array{TrackedReal} for mutation to work + # Recurse to all Array{TrackedArray} + _prob = remake(prob, u0 = map(identity, _u0), p = _p, tspan = _tspan) + else + # use TrackedArray for efficiency of the tape + if typeof(prob) <: + Union{SciMLBase.AbstractDDEProblem, SciMLBase.AbstractDAEProblem, + SciMLBase.AbstractSDDEProblem} + _f = function (u, p, h, t) # For DDE, but also works for (du,u,p,t) DAE + out = prob.f(u, p, h, t) + if out isa TrackedArray + return out + else + Tracker.collect(out) + end + end + + # Only define `g` for the stochastic ones + if typeof(prob) <: SciMLBase.AbstractSDEProblem + _g = function (u, p, h, t) + out = prob.g(u, p, h, t) + if out isa TrackedArray + return out + else + Tracker.collect(out) + end + end + _prob = remake(prob, + f = DiffEqBase.parameterless_type(prob.f){false, true}(_f, + _g), + u0 = _u0, p = _p, tspan = _tspan) + else + _prob = remake(prob, + f = DiffEqBase.parameterless_type(prob.f){false, true}(_f), + u0 = _u0, p = _p, tspan = _tspan) + end + elseif typeof(prob) <: + Union{SciMLBase.AbstractODEProblem, SciMLBase.AbstractSDEProblem} + _f = function (u, p, t) + out = prob.f(u, p, t) + if out isa TrackedArray + return out + else + Tracker.collect(out) + end + end + if typeof(prob) <: SciMLBase.AbstractSDEProblem + _g = function (u, p, t) + out = prob.g(u, p, t) + if out isa TrackedArray + return out + else + Tracker.collect(out) + end + end + _prob = remake(prob, + f = DiffEqBase.parameterless_type(prob.f){false, true}(_f, + _g), + u0 = _u0, p = _p, tspan = _tspan) + else + _prob = remake(prob, + f = DiffEqBase.parameterless_type(prob.f){false, true}(_f), + u0 = _u0, p = _p, tspan = _tspan) + end + else + error("TrackerAdjont does not currently support the specified problem type. Please open an issue.") + end + end + sol = solve(_prob, alg, args...; sensealg = DiffEqBase.SensitivityADPassThrough(), + kwargs...) + + if typeof(sol.u[1]) <: Array + return Array(sol) + else + tmp = vec(sol.u[1]) + for i in 2:length(sol.u) + tmp = hcat(tmp, vec(sol.u[i])) + end + return reshape(tmp, size(sol.u[1])..., length(sol.u)) + end + #adapt(typeof(u0),arr) + sol + end + + out, pullback = Tracker.forward(tracker_adjoint_forwardpass, u0, p) + function tracker_adjoint_backpass(ybar) + tmp = if eltype(ybar) <: Number && typeof(u0) <: Array + Array(ybar) + elseif eltype(ybar) <: Number # CuArray{Floats} + ybar + elseif typeof(ybar[1]) <: Array + return Array(ybar) + else + tmp = vec(ybar.u[1]) + for i in 2:length(ybar.u) + tmp = hcat(tmp, vec(ybar.u[i])) + end + return reshape(tmp, size(ybar.u[1])..., length(ybar.u)) + end + u0bar, pbar = pullback(tmp) + _u0bar = u0bar isa Tracker.TrackedArray ? Tracker.data(u0bar) : Tracker.data.(u0bar) + + if originator isa SciMLBase.TrackerOriginator || + originator isa SciMLBase.ReverseDiffOriginator + (NoTangent(), NoTangent(), _u0bar, Tracker.data(pbar), NoTangent(), + ntuple(_ -> NoTangent(), length(args))...) + else + (NoTangent(), NoTangent(), NoTangent(), _u0bar, Tracker.data(pbar), NoTangent(), + ntuple(_ -> NoTangent(), length(args))...) + end + end + + u = u0 isa Tracker.TrackedArray ? Tracker.data.(sol.u) : + Tracker.data.(Tracker.data.(sol.u)) + DiffEqBase.sensitivity_solution(sol, u, Tracker.data.(sol.t)), tracker_adjoint_backpass +end + +const REVERSEDIFF_ADJOINT_GPU_COMPATABILITY_MESSAGE = """ + ReverseDiffAdjoint is not compatible GPU-based array types. Use a different + sensitivity analysis method, like InterpolatingAdjoint or TrackerAdjoint, + in order to combine with GPUs. + """ + +struct ReverseDiffGPUStateCompatibilityError <: Exception end + +function Base.showerror(io::IO, e::ReverseDiffGPUStateCompatibilityError) + print(io, FORWARDDIFF_SENSITIVITY_PARAMETER_COMPATABILITY_MESSAGE) +end + +function DiffEqBase._concrete_solve_adjoint(prob, alg, sensealg::ReverseDiffAdjoint, + u0, p, originator::SciMLBase.ADOriginator, + args...; kwargs...) + if typeof(u0) isa GPUArraysCore.AbstractGPUArray + throw(ReverseDiffGPUStateCompatibilityError()) + end + + t = eltype(prob.tspan)[] + u = typeof(u0)[] + + local sol + + function reversediff_adjoint_forwardpass(_u0, _p) + if (convert_tspan(sensealg) === nothing && + ((haskey(kwargs, :callback) && has_continuous_callback(kwargs[:callback])))) || + (convert_tspan(sensealg) !== nothing && convert_tspan(sensealg)) + _tspan = convert.(eltype(_p), prob.tspan) + else + _tspan = prob.tspan + end + + if DiffEqBase.isinplace(prob) + # use Array{TrackedReal} for mutation to work + # Recurse to all Array{TrackedArray} + _prob = remake(prob, u0 = reshape([x for x in _u0], size(_u0)), p = _p, + tspan = _tspan) + else + # use TrackedArray for efficiency of the tape + _f(args...) = reduce(vcat, prob.f(args...)) + if prob isa SDEProblem + _g(args...) = reduce(vcat, prob.g(args...)) + _prob = remake(prob, + f = DiffEqBase.parameterless_type(prob.f){ + SciMLBase.isinplace(prob), + true}(_f, _g), + u0 = _u0, p = _p, tspan = _tspan) + else + _prob = remake(prob, + f = DiffEqBase.parameterless_type(prob.f){ + SciMLBase.isinplace(prob), + true}(_f), + u0 = _u0, p = _p, tspan = _tspan) + end + end + + sol = solve(_prob, alg, args...; sensealg = DiffEqBase.SensitivityADPassThrough(), + kwargs...) + t = sol.t + if DiffEqBase.isinplace(prob) + u = map.(ReverseDiff.value, sol.u) + else + u = map(ReverseDiff.value, sol.u) + end + Array(sol) + end + + tape = ReverseDiff.GradientTape(reversediff_adjoint_forwardpass, (u0, p)) + tu, tp = ReverseDiff.input_hook(tape) + output = ReverseDiff.output_hook(tape) + ReverseDiff.value!(tu, u0) + typeof(p) <: DiffEqBase.NullParameters || ReverseDiff.value!(tp, p) + ReverseDiff.forward_pass!(tape) + function reversediff_adjoint_backpass(ybar) + _ybar = if ybar isa VectorOfArray + Array(ybar) + elseif eltype(ybar) <: AbstractArray + Array(VectorOfArray(ybar)) + else + ybar + end + ReverseDiff.increment_deriv!(output, _ybar) + ReverseDiff.reverse_pass!(tape) + + if originator isa SciMLBase.TrackerOriginator || + originator isa SciMLBase.ReverseDiffOriginator + (NoTangent(), NoTangent(), ReverseDiff.deriv(tu), ReverseDiff.deriv(tp), + NoTangent(), ntuple(_ -> NoTangent(), length(args))...) + else + (NoTangent(), NoTangent(), NoTangent(), ReverseDiff.deriv(tu), + ReverseDiff.deriv(tp), NoTangent(), ntuple(_ -> NoTangent(), length(args))...) + end + end + Array(VectorOfArray(u)), reversediff_adjoint_backpass +end + +function DiffEqBase._concrete_solve_adjoint(prob, alg, + sensealg::AbstractShadowingSensitivityAlgorithm, + u0, p, originator::SciMLBase.ADOriginator, + args...; save_start = true, save_end = true, + saveat = eltype(prob.tspan)[], + save_idxs = nothing, + kwargs...) + if haskey(kwargs, :callback) + error("Sensitivity analysis based on Least Squares Shadowing is not compatible with callbacks. Please select another `sensealg`.") + else + _prob = remake(prob, u0 = u0, p = p) + end + + sol = solve(_prob, alg, args...; save_start = save_start, save_end = save_end, + saveat = saveat, kwargs...) + + if saveat isa Number + if _prob.tspan[2] > _prob.tspan[1] + ts = _prob.tspan[1]:convert(typeof(_prob.tspan[2]), abs(saveat)):_prob.tspan[2] + else + ts = _prob.tspan[2]:convert(typeof(_prob.tspan[2]), abs(saveat)):_prob.tspan[1] + end + _out = sol(ts) + out = if save_idxs === nothing + out = DiffEqBase.sensitivity_solution(sol, _out.u, sol.t) + else + out = DiffEqBase.sensitivity_solution(sol, + [_out[i][save_idxs] + for i in 1:length(_out)], ts) + end + # only_end + (length(ts) == 1 && ts[1] == _prob.tspan[2]) && + error("Sensitivity analysis based on Least Squares Shadowing requires a long-time averaged quantity.") + elseif isempty(saveat) + no_start = !save_start + no_end = !save_end + sol_idxs = 1:length(sol) + no_start && (sol_idxs = sol_idxs[2:end]) + no_end && (sol_idxs = sol_idxs[1:(end - 1)]) + only_end = length(sol_idxs) <= 1 + _u = sol.u[sol_idxs] + u = save_idxs === nothing ? _u : [x[save_idxs] for x in _u] + ts = sol.t[sol_idxs] + out = DiffEqBase.sensitivity_solution(sol, u, ts) + else + _saveat = saveat isa Array ? sort(saveat) : saveat # for minibatching + ts = _saveat + _out = sol(ts) + + out = if save_idxs === nothing + out = DiffEqBase.sensitivity_solution(sol, _out.u, ts) + else + out = DiffEqBase.sensitivity_solution(sol, + [_out[i][save_idxs] + for i in 1:length(_out)], ts) + end + # only_end + (length(ts) == 1 && ts[1] == _prob.tspan[2]) && + error("Sensitivity analysis based on Least Squares Shadowing requires a long-time averaged quantity.") + end + + _save_idxs = save_idxs === nothing ? Colon() : save_idxs + + function adjoint_sensitivity_backpass(Δ) + function df(_out, u, p, t, i) + if typeof(Δ) <: AbstractArray{<:AbstractArray} || typeof(Δ) <: DESolution + if typeof(_save_idxs) <: Number + _out[_save_idxs] = Δ[i][_save_idxs] + elseif _save_idxs isa Colon + vec(_out) .= vec(Δ[i]) + else + vec(@view(_out[_save_idxs])) .= vec(Δ[i][_save_idxs]) + end + else + if typeof(_save_idxs) <: Number + _out[_save_idxs] = adapt(DiffEqBase.parameterless_type(u0), + reshape(Δ, prod(size(Δ)[1:(end - 1)]), + size(Δ)[end])[_save_idxs, i]) + elseif _save_idxs isa Colon + vec(_out) .= vec(adapt(DiffEqBase.parameterless_type(u0), + reshape(Δ, prod(size(Δ)[1:(end - 1)]), + size(Δ)[end])[:, i])) + else + vec(@view(_out[_save_idxs])) .= vec(adapt(DiffEqBase.parameterless_type(u0), + reshape(Δ, + prod(size(Δ)[1:(end - 1)]), + size(Δ)[end])[:, i])) + end + end + end + + if sensealg isa ForwardLSS + lss_problem = ForwardLSSProblem(sol, sensealg, t = ts, dg_discrete = df) + dp = shadow_forward(lss_problem) + elseif sensealg isa AdjointLSS + adjointlss_problem = AdjointLSSProblem(sol, sensealg, t = ts, dg_discrete = df) + dp = shadow_adjoint(adjointlss_problem) + elseif sensealg isa NILSS + nilss_prob = NILSSProblem(_prob, sensealg, t = ts, dg_discrete = df) + dp = shadow_forward(nilss_prob, alg) + elseif sensealg isa NILSAS + nilsas_prob = NILSASProblem(_prob, sensealg, t = ts, dg_discrete = df) + dp = shadow_adjoint(nilsas_prob, alg) + else + error("No concrete_solve implementation found for sensealg `$sensealg`. Did you spell the sensitivity algorithm correctly? Please report this error.") + end + + if originator isa SciMLBase.TrackerOriginator || + originator isa SciMLBase.ReverseDiffOriginator + (NoTangent(), NoTangent(), NoTangent(), dp, NoTangent(), + ntuple(_ -> NoTangent(), length(args))...) + else + (NoTangent(), NoTangent(), NoTangent(), NoTangent(), dp, NoTangent(), + ntuple(_ -> NoTangent(), length(args))...) + end + end + out, adjoint_sensitivity_backpass +end + +function DiffEqBase._concrete_solve_adjoint(prob::Union{NonlinearProblem, SteadyStateProblem + }, + alg, sensealg::SteadyStateAdjoint, + u0, p, originator::SciMLBase.ADOriginator, + args...; save_idxs = nothing, kwargs...) + _prob = remake(prob, u0 = u0, p = p) + sol = solve(_prob, alg, args...; kwargs...) + _save_idxs = save_idxs === nothing ? Colon() : save_idxs + + if save_idxs === nothing + out = sol + else + out = DiffEqBase.sensitivity_solution(sol, sol[_save_idxs]) + end + + function steadystatebackpass(Δ) + # Δ = dg/dx or diffcache.dg_val + # del g/del p = 0 + dp = adjoint_sensitivities(sol, alg; sensealg = sensealg, g = nothing, dg = Δ, + save_idxs = save_idxs) + + if originator isa SciMLBase.TrackerOriginator || + originator isa SciMLBase.ReverseDiffOriginator + (NoTangent(), NoTangent(), NoTangent(), dp, NoTangent(), + ntuple(_ -> NoTangent(), length(args))...) + else + (NoTangent(), NoTangent(), NoTangent(), NoTangent(), dp, NoTangent(), + ntuple(_ -> NoTangent(), length(args))...) + end + end + out, steadystatebackpass +end + +function fix_endpoints(sensealg, sol, ts) + @warn "Endpoints do not match. Return code: $(sol.retcode). Likely your time range is not a multiple of `saveat`. sol.t[end]: $(sol.t[end]), ts[end]: $(ts[end])" + ts = collect(ts) + push!(ts, sol.t[end]) +end diff --git a/src/derivative_wrappers.jl b/src/derivative_wrappers.jl index f85e5452b..a5b19dc6a 100644 --- a/src/derivative_wrappers.jl +++ b/src/derivative_wrappers.jl @@ -1,172 +1,177 @@ # Not in FiniteDiff because `u` -> scalar isn't used anywhere else, # but could be upstreamed. -mutable struct UGradientWrapper{fType,tType,P} <: Function - f::fType - t::tType - p::P +mutable struct UGradientWrapper{fType, tType, P} <: Function + f::fType + t::tType + p::P end -(ff::UGradientWrapper)(uprev) = ff.f(uprev,ff.p,ff.t) +(ff::UGradientWrapper)(uprev) = ff.f(uprev, ff.p, ff.t) -mutable struct ParamGradientWrapper{fType,tType,uType} <: Function - f::fType - t::tType - u::uType +mutable struct ParamGradientWrapper{fType, tType, uType} <: Function + f::fType + t::tType + u::uType end -(ff::ParamGradientWrapper)(p) = ff.f(ff.u,p,ff.t) +(ff::ParamGradientWrapper)(p) = ff.f(ff.u, p, ff.t) # the next four definitions are only needed in case of non-diagonal SDEs -mutable struct ParamNonDiagNoiseGradientWrapper{fType,tType,uType} <: Function - f::fType - t::tType - u::uType +mutable struct ParamNonDiagNoiseGradientWrapper{fType, tType, uType} <: Function + f::fType + t::tType + u::uType end -(ff::ParamNonDiagNoiseGradientWrapper)(p) = vec(ff.f(ff.u,p,ff.t)) +(ff::ParamNonDiagNoiseGradientWrapper)(p) = vec(ff.f(ff.u, p, ff.t)) -mutable struct ParamNonDiagNoiseJacobianWrapper{fType,tType,uType,duType} <: Function - f::fType - t::tType - u::uType - du::duType +mutable struct ParamNonDiagNoiseJacobianWrapper{fType, tType, uType, duType} <: Function + f::fType + t::tType + u::uType + du::duType end function (ff::ParamNonDiagNoiseJacobianWrapper)(p) - du1 = similar(p, size(ff.du)) - ff.f(du1,ff.u,p,ff.t) - return vec(du1) + du1 = similar(p, size(ff.du)) + ff.f(du1, ff.u, p, ff.t) + return vec(du1) end -function (ff::ParamNonDiagNoiseJacobianWrapper)(du1,p) - ff.f(du1,ff.u,p,ff.t) - return vec(du1) +function (ff::ParamNonDiagNoiseJacobianWrapper)(du1, p) + ff.f(du1, ff.u, p, ff.t) + return vec(du1) end -mutable struct UNonDiagNoiseGradientWrapper{fType,tType,P} <: Function - f::fType - t::tType - p::P +mutable struct UNonDiagNoiseGradientWrapper{fType, tType, P} <: Function + f::fType + t::tType + p::P end -(ff::UNonDiagNoiseGradientWrapper)(uprev) = vec(ff.f(uprev,ff.p,ff.t)) +(ff::UNonDiagNoiseGradientWrapper)(uprev) = vec(ff.f(uprev, ff.p, ff.t)) -mutable struct UNonDiagNoiseJacobianWrapper{fType,tType,P,duType} <: Function - f::fType - t::tType - p::P - du::duType +mutable struct UNonDiagNoiseJacobianWrapper{fType, tType, P, duType} <: Function + f::fType + t::tType + p::P + du::duType end -(ff::UNonDiagNoiseJacobianWrapper)(uprev) = (du1 = similar(ff.du); ff.f(du1,uprev,ff.p,ff.t); vec(du1)) +function (ff::UNonDiagNoiseJacobianWrapper)(uprev) + (du1 = similar(ff.du); ff.f(du1, uprev, ff.p, ff.t); vec(du1)) +end -function (ff::UNonDiagNoiseJacobianWrapper)(du1,uprev) - ff.f(du1,uprev,ff.p,ff.t) - return vec(du1) +function (ff::UNonDiagNoiseJacobianWrapper)(du1, uprev) + ff.f(du1, uprev, ff.p, ff.t) + return vec(du1) end # RODE wrappers -mutable struct RODEUJacobianWrapper{fType,tType,P,WType} <: Function - f::fType - t::tType - p::P - W::WType +mutable struct RODEUJacobianWrapper{fType, tType, P, WType} <: Function + f::fType + t::tType + p::P + W::WType end -(ff::RODEUJacobianWrapper)(du1,uprev) = ff.f(du1,uprev,ff.p,ff.t,ff.W) -(ff::RODEUJacobianWrapper)(uprev) = (du1 = similar(uprev); ff.f(du1,uprev,ff.p,ff.t,ff.W); du1) +(ff::RODEUJacobianWrapper)(du1, uprev) = ff.f(du1, uprev, ff.p, ff.t, ff.W) +function (ff::RODEUJacobianWrapper)(uprev) + (du1 = similar(uprev); ff.f(du1, uprev, ff.p, ff.t, ff.W); du1) +end -mutable struct RODEUDerivativeWrapper{F,tType,P,WType} <: Function - f::F - t::tType - p::P - W::WType +mutable struct RODEUDerivativeWrapper{F, tType, P, WType} <: Function + f::F + t::tType + p::P + W::WType end -(ff::RODEUDerivativeWrapper)(u) = ff.f(u,ff.p,ff.t,ff.W) +(ff::RODEUDerivativeWrapper)(u) = ff.f(u, ff.p, ff.t, ff.W) -mutable struct RODEUGradientWrapper{fType,tType,P,WType} <: Function - f::fType - t::tType - p::P - W::WType +mutable struct RODEUGradientWrapper{fType, tType, P, WType} <: Function + f::fType + t::tType + p::P + W::WType end -(ff::RODEUGradientWrapper)(uprev) = ff.f(uprev,ff.p,ff.t,ff.W) +(ff::RODEUGradientWrapper)(uprev) = ff.f(uprev, ff.p, ff.t, ff.W) -mutable struct RODEParamGradientWrapper{fType,tType,uType,WType} <: Function - f::fType - t::tType - u::uType - W::WType +mutable struct RODEParamGradientWrapper{fType, tType, uType, WType} <: Function + f::fType + t::tType + u::uType + W::WType end -(ff::RODEParamGradientWrapper)(p) = ff.f(ff.u,p,ff.t,ff.W) +(ff::RODEParamGradientWrapper)(p) = ff.f(ff.u, p, ff.t, ff.W) -mutable struct RODEParamJacobianWrapper{fType,tType,uType,WType} <: Function - f::fType - t::tType - u::uType - W::WType +mutable struct RODEParamJacobianWrapper{fType, tType, uType, WType} <: Function + f::fType + t::tType + u::uType + W::WType end -(ff::RODEParamJacobianWrapper)(du1,p) = ff.f(du1,ff.u,p,ff.t,ff.W) +(ff::RODEParamJacobianWrapper)(du1, p) = ff.f(du1, ff.u, p, ff.t, ff.W) function (ff::RODEParamJacobianWrapper)(p) - du1 = similar(p, size(ff.u)) - ff.f(du1,ff.u,p,ff.t,ff.W) - return du1 + du1 = similar(p, size(ff.u)) + ff.f(du1, ff.u, p, ff.t, ff.W) + return du1 end -Base.@pure function determine_chunksize(u,alg::DiffEqBase.AbstractSensitivityAlgorithm) - determine_chunksize(u,get_chunksize(alg)) +Base.@pure function determine_chunksize(u, alg::DiffEqBase.AbstractSensitivityAlgorithm) + determine_chunksize(u, get_chunksize(alg)) end -Base.@pure function determine_chunksize(u,CS) - if CS != 0 - return CS - else - return ForwardDiff.pickchunksize(length(u)) - end +Base.@pure function determine_chunksize(u, CS) + if CS != 0 + return CS + else + return ForwardDiff.pickchunksize(length(u)) + end end -function jacobian(f, x::AbstractArray{<:Number}, alg::DiffEqBase.AbstractSensitivityAlgorithm) - if alg_autodiff(alg) - J = ForwardDiff.jacobian(f, x) - else - J = FiniteDiff.finite_difference_jacobian(f, x) - end - return J +function jacobian(f, x::AbstractArray{<:Number}, + alg::DiffEqBase.AbstractSensitivityAlgorithm) + if alg_autodiff(alg) + J = ForwardDiff.jacobian(f, x) + else + J = FiniteDiff.finite_difference_jacobian(f, x) + end + return J end - function jacobian!(J::AbstractMatrix{<:Number}, f, x::AbstractArray{<:Number}, - fx::Union{Nothing,AbstractArray{<:Number}}, alg::DiffEqBase.AbstractSensitivityAlgorithm, jac_config) - if alg_autodiff(alg) - if fx === nothing - ForwardDiff.jacobian!(J, f, x) + fx::Union{Nothing, AbstractArray{<:Number}}, + alg::DiffEqBase.AbstractSensitivityAlgorithm, jac_config) + if alg_autodiff(alg) + if fx === nothing + ForwardDiff.jacobian!(J, f, x) + else + ForwardDiff.jacobian!(J, f, fx, x, jac_config) + end else - ForwardDiff.jacobian!(J, f, fx, x, jac_config) + FiniteDiff.finite_difference_jacobian!(J, f, x, jac_config) end - else - FiniteDiff.finite_difference_jacobian!(J, f, x, jac_config) - end - nothing + nothing end function derivative!(df::AbstractArray{<:Number}, f, - x::Number, - alg::DiffEqBase.AbstractSensitivityAlgorithm, der_config) + x::Number, + alg::DiffEqBase.AbstractSensitivityAlgorithm, der_config) if alg_autodiff(alg) - ForwardDiff.derivative!(df, f, x, ) # der_config doesn't work + ForwardDiff.derivative!(df, f, x) # der_config doesn't work else - FiniteDiff.finite_difference_derivative!(df, f, x, der_config) + FiniteDiff.finite_difference_derivative!(df, f, x, der_config) end nothing end function gradient!(df::AbstractArray{<:Number}, f, - x::Union{Number,AbstractArray{<:Number}}, + x::Union{Number, AbstractArray{<:Number}}, alg::DiffEqBase.AbstractSensitivityAlgorithm, grad_config) if alg_autodiff(alg) ForwardDiff.gradient!(df, f, x, grad_config) @@ -183,723 +188,742 @@ end """ function jacobianvec!(Jv::AbstractArray{<:Number}, f, x::AbstractArray{<:Number}, v, alg::DiffEqBase.AbstractSensitivityAlgorithm, config) - if alg_autodiff(alg) - buffer, seed = config - TD = typeof(first(seed)) - T = typeof(first(seed).partials) - DiffEqBase.@.. seed = TD(x, T(tuple(v))) - f(buffer, seed) - Jv .= ForwardDiff.partials.(buffer, 1) - else - buffer1, buffer2 = config - f(buffer1,x) - T = eltype(x) - # Should it be min? max? mean? - ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(x))) - @. x += ϵ*v - f(buffer2,x) - @. x -= ϵ*v - @. Jv = (buffer2 - buffer1)/ϵ - end - nothing + if alg_autodiff(alg) + buffer, seed = config + TD = typeof(first(seed)) + T = typeof(first(seed).partials) + DiffEqBase.@.. seed = TD(x, T(tuple(v))) + f(buffer, seed) + Jv .= ForwardDiff.partials.(buffer, 1) + else + buffer1, buffer2 = config + f(buffer1, x) + T = eltype(x) + # Should it be min? max? mean? + ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(x))) + @. x += ϵ * v + f(buffer2, x) + @. x -= ϵ * v + @. Jv = (buffer2 - buffer1) / ϵ + end + nothing end function jacobianmat!(JM::AbstractMatrix{<:Number}, f, x::AbstractArray{<:Number}, M, alg::DiffEqBase.AbstractSensitivityAlgorithm, config) - buffer, seed = config - T = eltype(seed) - numparams = length(ForwardDiff.partials(seed[1])) - for i in eachindex(seed) - seed[i] = T(x[i],ForwardDiff.Partials(ntuple(j -> M[i,j], numparams))) - end - f(buffer,seed) - for (j,dual) in enumerate(buffer) - for (i,partial) in enumerate(ForwardDiff.partials(dual)) - JM[j,i] = partial + buffer, seed = config + T = eltype(seed) + numparams = length(ForwardDiff.partials(seed[1])) + for i in eachindex(seed) + seed[i] = T(x[i], ForwardDiff.Partials(ntuple(j -> M[i, j], numparams))) end - end - return nothing + f(buffer, seed) + for (j, dual) in enumerate(buffer) + for (i, partial) in enumerate(ForwardDiff.partials(dual)) + JM[j, i] = partial + end + end + return nothing end function vecjacobian!(dλ, y, λ, p, t, S::TS; - dgrad=nothing, dy=nothing, W=nothing) where TS<:SensitivityFunction - _vecjacobian!(dλ, y, λ, p, t, S, S.sensealg.autojacvec, dgrad, dy, W) - return -end - -function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::Bool, dgrad, dy, W) where TS<:SensitivityFunction - @unpack sensealg, f = S - prob = getprob(S) - - @unpack J, uf, f_cache, jac_config = S.diffcache - if !(prob isa DiffEqBase.SteadyStateProblem) - if W===nothing - if DiffEqBase.has_jac(f) - f.jac(J,y,p,t) # Calculate the Jacobian into J - else - uf.t = t - uf.p = p - jacobian!(J, uf, y, f_cache, sensealg, jac_config) - end - else - if DiffEqBase.has_jac(f) - f.jac(J,y,p,t,W) # Calculate the Jacobian into J - else - uf.t = t - uf.p = p - uf.W = W - jacobian!(J, uf, y, f_cache, sensealg, jac_config) - end - end - mul!(dλ',λ',J) - end - if dgrad !== nothing - @unpack pJ, pf, paramjac_config = S.diffcache - if W===nothing - if DiffEqBase.has_paramjac(f) - # Calculate the parameter Jacobian into pJ - f.paramjac(pJ,y,p,t) - else - pf.t = t - pf.u = y - if inplace_sensitivity(S) - jacobian!(pJ, pf, p, f_cache, sensealg, paramjac_config) + dgrad = nothing, dy = nothing, + W = nothing) where {TS <: SensitivityFunction} + _vecjacobian!(dλ, y, λ, p, t, S, S.sensealg.autojacvec, dgrad, dy, W) + return +end + +function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::Bool, dgrad, dy, + W) where {TS <: SensitivityFunction} + @unpack sensealg, f = S + prob = getprob(S) + + @unpack J, uf, f_cache, jac_config = S.diffcache + if !(prob isa DiffEqBase.SteadyStateProblem) + if W === nothing + if DiffEqBase.has_jac(f) + f.jac(J, y, p, t) # Calculate the Jacobian into J + else + uf.t = t + uf.p = p + jacobian!(J, uf, y, f_cache, sensealg, jac_config) + end else - temp = jacobian(pf, p, sensealg) - pJ .= temp + if DiffEqBase.has_jac(f) + f.jac(J, y, p, t, W) # Calculate the Jacobian into J + else + uf.t = t + uf.p = p + uf.W = W + jacobian!(J, uf, y, f_cache, sensealg, jac_config) + end end - end - else - if DiffEqBase.has_paramjac(f) - # Calculate the parameter Jacobian into pJ - f.paramjac(pJ,y,p,t,W) - else - pf.t = t - pf.u = y - pf.W = W - if inplace_sensitivity(S) - jacobian!(pJ, pf, p, f_cache, sensealg, paramjac_config) + mul!(dλ', λ', J) + end + if dgrad !== nothing + @unpack pJ, pf, paramjac_config = S.diffcache + if W === nothing + if DiffEqBase.has_paramjac(f) + # Calculate the parameter Jacobian into pJ + f.paramjac(pJ, y, p, t) + else + pf.t = t + pf.u = y + if inplace_sensitivity(S) + jacobian!(pJ, pf, p, f_cache, sensealg, paramjac_config) + else + temp = jacobian(pf, p, sensealg) + pJ .= temp + end + end else - temp = jacobian(pf, p, sensealg) - pJ .= temp + if DiffEqBase.has_paramjac(f) + # Calculate the parameter Jacobian into pJ + f.paramjac(pJ, y, p, t, W) + else + pf.t = t + pf.u = y + pf.W = W + if inplace_sensitivity(S) + jacobian!(pJ, pf, p, f_cache, sensealg, paramjac_config) + else + temp = jacobian(pf, p, sensealg) + pJ .= temp + end + end end - end + mul!(dgrad', λ', pJ) end - mul!(dgrad',λ',pJ) - end - if dy !== nothing - if W===nothing - if inplace_sensitivity(S) - f(dy, y, p, t) - else - dy[:] .= vec(f(y, p, t)) - end - else - if inplace_sensitivity(S) - f(dy, y, p, t, W) - else - dy[:] .= vec(f(y, p, t, W)) - end + if dy !== nothing + if W === nothing + if inplace_sensitivity(S) + f(dy, y, p, t) + else + dy[:] .= vec(f(y, p, t)) + end + else + if inplace_sensitivity(S) + f(dy, y, p, t, W) + else + dy[:] .= vec(f(y, p, t, W)) + end + end end - end - return + return end -const TRACKERVJP_NOTHING_MESSAGE = -""" -`nothing` returned from a Tracker vector-Jacobian product (vjp) calculation. -This indicates that your function `f` is not a function of `p` or `u`, i.e. that -the derivative is constant zero. In many cases this is due to an error in -the model definition, for example accidentally using a global parameter -instead of the one in the model (`f(u,p,t)= _p .* u`). - -One common cause of this is using Flux neural networks with implicit parameters, -for example `f(u,p,t) = NN(u)` does not use `p` and therefore will have a zero -derivative. The answer is to use `Flux.destructure` in this case, for example: - -```julia -p,re = Flux.destructure(NN) -f(u,p,t) = re(p)(u) -prob = ODEProblem(f,u0,tspan,p) -``` - -Note that restructuring outside of `f`, i.e. `reNN = re(p); f(u,p,t) = reNN(u)` will -also trigger a zero gradient. The `p` must be used inside of `f`, not globally outside. - -If this zero gradient with respect to `u` or `p` is intended, then one can set -`TrackerVJP(allow_nothing=true)` to override this error message. For example: - -```julia -solve(prob,alg,sensealg=InterpolatingAdjoint(autojacvec=TrackerVJP(allow_nothing=true))) -``` -""" +const TRACKERVJP_NOTHING_MESSAGE = """ + `nothing` returned from a Tracker vector-Jacobian product (vjp) calculation. + This indicates that your function `f` is not a function of `p` or `u`, i.e. that + the derivative is constant zero. In many cases this is due to an error in + the model definition, for example accidentally using a global parameter + instead of the one in the model (`f(u,p,t)= _p .* u`). + + One common cause of this is using Flux neural networks with implicit parameters, + for example `f(u,p,t) = NN(u)` does not use `p` and therefore will have a zero + derivative. The answer is to use `Flux.destructure` in this case, for example: + + ```julia + p,re = Flux.destructure(NN) + f(u,p,t) = re(p)(u) + prob = ODEProblem(f,u0,tspan,p) + ``` + + Note that restructuring outside of `f`, i.e. `reNN = re(p); f(u,p,t) = reNN(u)` will + also trigger a zero gradient. The `p` must be used inside of `f`, not globally outside. + + If this zero gradient with respect to `u` or `p` is intended, then one can set + `TrackerVJP(allow_nothing=true)` to override this error message. For example: + + ```julia + solve(prob,alg,sensealg=InterpolatingAdjoint(autojacvec=TrackerVJP(allow_nothing=true))) + ``` + """ struct TrackerVJPNothingError <: Exception end function Base.showerror(io::IO, e::TrackerVJPNothingError) - print(io, TRACKERVJP_NOTHING_MESSAGE) -end - -function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::TrackerVJP, dgrad, dy, W) where TS<:SensitivityFunction - @unpack sensealg, f = S - isautojacvec = get_jacvec(sensealg) - if inplace_sensitivity(S) - if W===nothing - _dy, back = Tracker.forward(y, p) do u, p - out_ = map(zero, u) - f(out_, u, p, t) - Tracker.collect(out_) - end - else - _dy, back = Tracker.forward(y, p) do u, p - out_ = map(zero, u) - f(out_, u, p, t, W) - Tracker.collect(out_) - end - end + print(io, TRACKERVJP_NOTHING_MESSAGE) +end - if !(typeof(_dy) isa TrackedArray) && !(eltype(_dy) <: Tracker.TrackedReal) && - !sensealg.autojacvec.allow_nothing - throw(TrackerVJPNothingError()) - end +function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::TrackerVJP, dgrad, dy, + W) where {TS <: SensitivityFunction} + @unpack sensealg, f = S + isautojacvec = get_jacvec(sensealg) + if inplace_sensitivity(S) + if W === nothing + _dy, back = Tracker.forward(y, p) do u, p + out_ = map(zero, u) + f(out_, u, p, t) + Tracker.collect(out_) + end + else + _dy, back = Tracker.forward(y, p) do u, p + out_ = map(zero, u) + f(out_, u, p, t, W) + Tracker.collect(out_) + end + end + if !(typeof(_dy) isa TrackedArray) && !(eltype(_dy) <: Tracker.TrackedReal) && + !sensealg.autojacvec.allow_nothing + throw(TrackerVJPNothingError()) + end - # Grab values from `_dy` before `back` in case mutated - dy !== nothing && (dy[:] .= vec(Tracker.data(_dy))) + # Grab values from `_dy` before `back` in case mutated + dy !== nothing && (dy[:] .= vec(Tracker.data(_dy))) - tmp1, tmp2 = Tracker.data.(back(λ)) - dλ[:] .= vec(tmp1) - dgrad !== nothing && (dgrad[:] .= vec(tmp2)) - else - if W===nothing - _dy, back = Tracker.forward(y, p) do u, p - Tracker.collect(f(u, p, t)) - end + tmp1, tmp2 = Tracker.data.(back(λ)) + dλ[:] .= vec(tmp1) + dgrad !== nothing && (dgrad[:] .= vec(tmp2)) else - _dy, back = Tracker.forward(y, p) do u, p - Tracker.collect(f(u, p, t, W)) - end - end + if W === nothing + _dy, back = Tracker.forward(y, p) do u, p + Tracker.collect(f(u, p, t)) + end + else + _dy, back = Tracker.forward(y, p) do u, p + Tracker.collect(f(u, p, t, W)) + end + end - if !(typeof(_dy) isa TrackedArray) && !(eltype(_dy) <: Tracker.TrackedReal) && - !sensealg.autojacvec.allow_nothing - throw(TrackerVJPNothingError()) + if !(typeof(_dy) isa TrackedArray) && !(eltype(_dy) <: Tracker.TrackedReal) && + !sensealg.autojacvec.allow_nothing + throw(TrackerVJPNothingError()) + end + + # Grab values from `_dy` before `back` in case mutated + dy !== nothing && (dy[:] .= vec(Tracker.data(_dy))) + + tmp1, tmp2 = Tracker.data.(back(λ)) + dλ[:] .= vec(tmp1) + dgrad !== nothing && (dgrad[:] .= vec(tmp2)) end + return +end + +function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::ReverseDiffVJP, dgrad, dy, + W) where {TS <: SensitivityFunction} + @unpack sensealg, f = S + prob = getprob(S) + isautojacvec = get_jacvec(sensealg) - # Grab values from `_dy` before `back` in case mutated - dy !== nothing && (dy[:] .= vec(Tracker.data(_dy))) - - tmp1, tmp2 = Tracker.data.(back(λ)) - dλ[:] .= vec(tmp1) - dgrad !== nothing && (dgrad[:] .= vec(tmp2)) - end - return -end - -function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::ReverseDiffVJP, dgrad, dy, W) where TS<:SensitivityFunction - @unpack sensealg, f = S - prob = getprob(S) - isautojacvec = get_jacvec(sensealg) - - if typeof(p) <: DiffEqBase.NullParameters - _p = similar(y,(0,)) - else - _p = p - end - - if typeof(prob) <: SteadyStateProblem || (eltype(λ) <: eltype(prob.u0) && typeof(t) <: eltype(prob.u0) && compile_tape(sensealg.autojacvec)) - tape = S.diffcache.paramjac_config - - ## These other cases happen due to autodiff in stiff ODE solvers - elseif inplace_sensitivity(S) - _y = eltype(y) === eltype(λ) ? y : convert.(promote_type(eltype(y),eltype(λ)),y) - if W===nothing - tape = ReverseDiff.GradientTape((_y, _p, [t])) do u,p,t - du1 = similar(u, size(u)) - f(du1,u,p,first(t)) - return vec(du1) - end + if typeof(p) <: DiffEqBase.NullParameters + _p = similar(y, (0,)) else - _W = eltype(W) === eltype(λ) ? W : convert.(promote_type(eltype(W),eltype(λ)),W) - tape = ReverseDiff.GradientTape((_y, _p, [t], _W)) do u,p,t,Wloc - du1 = p !== nothing && p !== DiffEqBase.NullParameters() ? similar(p, size(u)) : similar(u) - f(du1,u,p,first(t),Wloc) - return vec(du1) - end - end - else - _y = eltype(y) === eltype(λ) ? y : convert.(promote_type(eltype(y),eltype(λ)),y) - if W===nothing - tape = ReverseDiff.GradientTape((_y, _p, [t])) do u,p,t - vec(f(u,p,first(t))) - end + _p = p + end + + if typeof(prob) <: SteadyStateProblem || + (eltype(λ) <: eltype(prob.u0) && typeof(t) <: eltype(prob.u0) && + compile_tape(sensealg.autojacvec)) + tape = S.diffcache.paramjac_config + + ## These other cases happen due to autodiff in stiff ODE solvers + elseif inplace_sensitivity(S) + _y = eltype(y) === eltype(λ) ? y : convert.(promote_type(eltype(y), eltype(λ)), y) + if W === nothing + tape = ReverseDiff.GradientTape((_y, _p, [t])) do u, p, t + du1 = similar(u, size(u)) + f(du1, u, p, first(t)) + return vec(du1) + end + else + _W = eltype(W) === eltype(λ) ? W : + convert.(promote_type(eltype(W), eltype(λ)), W) + tape = ReverseDiff.GradientTape((_y, _p, [t], _W)) do u, p, t, Wloc + du1 = p !== nothing && p !== DiffEqBase.NullParameters() ? + similar(p, size(u)) : similar(u) + f(du1, u, p, first(t), Wloc) + return vec(du1) + end + end else - _W = eltype(W) === eltype(λ) ? W : convert.(promote_type(eltype(W),eltype(λ)),W) - tape = ReverseDiff.GradientTape((_y, _p, [t], _W)) do u,p,t,Wloc - vec(f(u,p,first(t),Wloc)) - end + _y = eltype(y) === eltype(λ) ? y : convert.(promote_type(eltype(y), eltype(λ)), y) + if W === nothing + tape = ReverseDiff.GradientTape((_y, _p, [t])) do u, p, t + vec(f(u, p, first(t))) + end + else + _W = eltype(W) === eltype(λ) ? W : + convert.(promote_type(eltype(W), eltype(λ)), W) + tape = ReverseDiff.GradientTape((_y, _p, [t], _W)) do u, p, t, Wloc + vec(f(u, p, first(t), Wloc)) + end + end end - end - if prob isa DiffEqBase.SteadyStateProblem - tu, tp = ReverseDiff.input_hook(tape) - else - if W===nothing - tu, tp, tt = ReverseDiff.input_hook(tape) + if prob isa DiffEqBase.SteadyStateProblem + tu, tp = ReverseDiff.input_hook(tape) else - tu, tp, tt, tW = ReverseDiff.input_hook(tape) + if W === nothing + tu, tp, tt = ReverseDiff.input_hook(tape) + else + tu, tp, tt, tW = ReverseDiff.input_hook(tape) + end end - end - output = ReverseDiff.output_hook(tape) - ReverseDiff.unseed!(tu) # clear any "leftover" derivatives from previous calls - ReverseDiff.unseed!(tp) - if !(prob isa DiffEqBase.SteadyStateProblem) - ReverseDiff.unseed!(tt) - end - W !== nothing && ReverseDiff.unseed!(tW) - ReverseDiff.value!(tu, y) - typeof(p) <: DiffEqBase.NullParameters || ReverseDiff.value!(tp, p) - if !(prob isa DiffEqBase.SteadyStateProblem) - ReverseDiff.value!(tt, [t]) - end - W !== nothing && ReverseDiff.value!(tW, W) - ReverseDiff.forward_pass!(tape) - ReverseDiff.increment_deriv!(output, λ) - ReverseDiff.reverse_pass!(tape) - copyto!(vec(dλ), ReverseDiff.deriv(tu)) - dgrad !== nothing && copyto!(vec(dgrad), ReverseDiff.deriv(tp)) - ReverseDiff.pull_value!(output) - dy !== nothing && copyto!(vec(dy), ReverseDiff.value(output)) - return -end - -const ZYGOTEVJP_NOTHING_MESSAGE = -""" -`nothing` returned from a Zygote vector-Jacobian product (vjp) calculation. -This indicates that your function `f` is not a function of `p` or `u`, i.e. that -the derivative is constant zero. In many cases this is due to an error in -the model definition, for example accidentally using a global parameter -instead of the one in the model (`f(u,p,t)= _p .* u`). - -One common cause of this is using Flux neural networks with implicit parameters, -for example `f(u,p,t) = NN(u)` does not use `p` and therefore will have a zero -derivative. The answer is to use `Flux.destructure` in this case, for example: - -```julia -p,re = Flux.destructure(NN) -f(u,p,t) = re(p)(u) -prob = ODEProblem(f,u0,tspan,p) -``` - -Note that restructuring outside of `f`, i.e. `reNN = re(p); f(u,p,t) = reNN(u)` will -also trigger a zero gradient. The `p` must be used inside of `f`, not globally outside. - -If this zero gradient with respect to `u` or `p` is intended, then one can set -`ZygoteVJP(allow_nothing=true)` to override this error message, for example: - -```julia -solve(prob,alg,sensealg=InterpolatingAdjoint(autojacvec=ZygoteVJP(allow_nothing=true))) -``` -""" + output = ReverseDiff.output_hook(tape) + ReverseDiff.unseed!(tu) # clear any "leftover" derivatives from previous calls + ReverseDiff.unseed!(tp) + if !(prob isa DiffEqBase.SteadyStateProblem) + ReverseDiff.unseed!(tt) + end + W !== nothing && ReverseDiff.unseed!(tW) + ReverseDiff.value!(tu, y) + typeof(p) <: DiffEqBase.NullParameters || ReverseDiff.value!(tp, p) + if !(prob isa DiffEqBase.SteadyStateProblem) + ReverseDiff.value!(tt, [t]) + end + W !== nothing && ReverseDiff.value!(tW, W) + ReverseDiff.forward_pass!(tape) + ReverseDiff.increment_deriv!(output, λ) + ReverseDiff.reverse_pass!(tape) + copyto!(vec(dλ), ReverseDiff.deriv(tu)) + dgrad !== nothing && copyto!(vec(dgrad), ReverseDiff.deriv(tp)) + ReverseDiff.pull_value!(output) + dy !== nothing && copyto!(vec(dy), ReverseDiff.value(output)) + return +end + +const ZYGOTEVJP_NOTHING_MESSAGE = """ + `nothing` returned from a Zygote vector-Jacobian product (vjp) calculation. + This indicates that your function `f` is not a function of `p` or `u`, i.e. that + the derivative is constant zero. In many cases this is due to an error in + the model definition, for example accidentally using a global parameter + instead of the one in the model (`f(u,p,t)= _p .* u`). + + One common cause of this is using Flux neural networks with implicit parameters, + for example `f(u,p,t) = NN(u)` does not use `p` and therefore will have a zero + derivative. The answer is to use `Flux.destructure` in this case, for example: + + ```julia + p,re = Flux.destructure(NN) + f(u,p,t) = re(p)(u) + prob = ODEProblem(f,u0,tspan,p) + ``` + + Note that restructuring outside of `f`, i.e. `reNN = re(p); f(u,p,t) = reNN(u)` will + also trigger a zero gradient. The `p` must be used inside of `f`, not globally outside. + + If this zero gradient with respect to `u` or `p` is intended, then one can set + `ZygoteVJP(allow_nothing=true)` to override this error message, for example: + + ```julia + solve(prob,alg,sensealg=InterpolatingAdjoint(autojacvec=ZygoteVJP(allow_nothing=true))) + ``` + """ struct ZygoteVJPNothingError <: Exception end function Base.showerror(io::IO, e::ZygoteVJPNothingError) - print(io, ZYGOTEVJP_NOTHING_MESSAGE) + print(io, ZYGOTEVJP_NOTHING_MESSAGE) end -function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::ZygoteVJP, dgrad, dy, W) where TS<:SensitivityFunction - @unpack sensealg, f = S - prob = getprob(S) +function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::ZygoteVJP, dgrad, dy, + W) where {TS <: SensitivityFunction} + @unpack sensealg, f = S + prob = getprob(S) - isautojacvec = get_jacvec(sensealg) - if inplace_sensitivity(S) - if W===nothing - _dy, back = Zygote.pullback(y, p) do u, p - out_ = Zygote.Buffer(similar(u)) - f(out_, u, p, t) - vec(copy(out_)) - end - else - _dy, back = Zygote.pullback(y, p) do u, p - out_ = Zygote.Buffer(similar(u)) - f(out_, u, p, t, W) - vec(copy(out_)) - end - end + isautojacvec = get_jacvec(sensealg) + if inplace_sensitivity(S) + if W === nothing + _dy, back = Zygote.pullback(y, p) do u, p + out_ = Zygote.Buffer(similar(u)) + f(out_, u, p, t) + vec(copy(out_)) + end + else + _dy, back = Zygote.pullback(y, p) do u, p + out_ = Zygote.Buffer(similar(u)) + f(out_, u, p, t, W) + vec(copy(out_)) + end + end - # Grab values from `_dy` before `back` in case mutated - dy !== nothing && (dy[:] .= vec(_dy)) - - tmp1,tmp2 = back(λ) - dλ[:] .= vec(tmp1) - if dgrad !== nothing - if tmp2 === nothing && !sensealg.autojacvec.allow_nothing - throw(ZygoteVJPNothingError()) - else - (dgrad[:] .= vec(tmp2)) - end - end - else - if W===nothing - _dy, back = Zygote.pullback(y, p) do u, p - vec(f(u, p, t)) - end + # Grab values from `_dy` before `back` in case mutated + dy !== nothing && (dy[:] .= vec(_dy)) + + tmp1, tmp2 = back(λ) + dλ[:] .= vec(tmp1) + if dgrad !== nothing + if tmp2 === nothing && !sensealg.autojacvec.allow_nothing + throw(ZygoteVJPNothingError()) + else + (dgrad[:] .= vec(tmp2)) + end + end else - _dy, back = Zygote.pullback(y, p) do u, p - vec(f(u, p, t, W)) - end - end + if W === nothing + _dy, back = Zygote.pullback(y, p) do u, p + vec(f(u, p, t)) + end + else + _dy, back = Zygote.pullback(y, p) do u, p + vec(f(u, p, t, W)) + end + end - # Grab values from `_dy` before `back` in case mutated - dy !== nothing && (dy[:] .= vec(_dy)) + # Grab values from `_dy` before `back` in case mutated + dy !== nothing && (dy[:] .= vec(_dy)) - tmp1, tmp2 = back(λ) - if tmp1 === nothing && !sensealg.autojacvec.allow_nothing - throw(ZygoteVJPNothingError()) - elseif tmp1 !== nothing - (dλ[:] .= vec(tmp1)) - end + tmp1, tmp2 = back(λ) + if tmp1 === nothing && !sensealg.autojacvec.allow_nothing + throw(ZygoteVJPNothingError()) + elseif tmp1 !== nothing + (dλ[:] .= vec(tmp1)) + end - if dgrad !== nothing - if tmp2 === nothing && !sensealg.autojacvec.allow_nothing - throw(ZygoteVJPNothingError()) - elseif tmp2 !== nothing - (dgrad[:] .= vec(tmp2)) - end + if dgrad !== nothing + if tmp2 === nothing && !sensealg.autojacvec.allow_nothing + throw(ZygoteVJPNothingError()) + elseif tmp2 !== nothing + (dgrad[:] .= vec(tmp2)) + end + end end - end - return + return end -function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::EnzymeVJP, dgrad, dy, W) where TS<:SensitivityFunction - @unpack sensealg = S - f = S.f.f +function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::EnzymeVJP, dgrad, dy, + W) where {TS <: SensitivityFunction} + @unpack sensealg = S + f = S.f.f - prob = getprob(S) + prob = getprob(S) - tmp1,tmp2,tmp3,tmp4 = S.diffcache.paramjac_config + tmp1, tmp2, tmp3, tmp4 = S.diffcache.paramjac_config - tmp1 .= 0 # should be removed for dλ + tmp1 .= 0 # should be removed for dλ - #if dgrad !== nothing + #if dgrad !== nothing # tmp2 = dgrad - #else - dup = if !(typeof(tmp2) <: DiffEqBase.NullParameters) - tmp2 .= 0 - Enzyme.Duplicated(p, tmp2) - else - p - end - #end - - #if dy !== nothing - # tmp3 = dy - #else - tmp3 .= 0 - #end - - vec(tmp4) .= vec(λ) - - isautojacvec = get_jacvec(sensealg) - if inplace_sensitivity(S) - if W===nothing - Enzyme.autodiff(S.diffcache.pf,Enzyme.Duplicated(tmp3, tmp4), - Enzyme.Duplicated(y, tmp1), - dup, - t) + #else + dup = if !(typeof(tmp2) <: DiffEqBase.NullParameters) + tmp2 .= 0 + Enzyme.Duplicated(p, tmp2) else - Enzyme.autodiff(S.diffcache.pf,Enzyme.Duplicated(tmp3, tmp4), - Enzyme.Duplicated(y, tmp1), - dup, - t,W) + p end + #end + + #if dy !== nothing + # tmp3 = dy + #else + tmp3 .= 0 + #end + + vec(tmp4) .= vec(λ) - dλ .= tmp1 - dgrad !== nothing && (dgrad[:] .= vec(tmp2)) - dy !== nothing && (dy .= tmp3) - else - if W===nothing - Enzyme.autodiff(S.diffcache.pf,Enzyme.Duplicated(tmp3, tmp4), - Enzyme.Duplicated(y, tmp1), - dup,t) + isautojacvec = get_jacvec(sensealg) + if inplace_sensitivity(S) + if W === nothing + Enzyme.autodiff(S.diffcache.pf, Enzyme.Duplicated(tmp3, tmp4), + Enzyme.Duplicated(y, tmp1), + dup, + t) + else + Enzyme.autodiff(S.diffcache.pf, Enzyme.Duplicated(tmp3, tmp4), + Enzyme.Duplicated(y, tmp1), + dup, + t, W) + end + + dλ .= tmp1 + dgrad !== nothing && (dgrad[:] .= vec(tmp2)) + dy !== nothing && (dy .= tmp3) else - Enzyme.autodiff(S.diffcache.pf,Enzyme.Duplicated(tmp3, tmp4), - Enzyme.Duplicated(y, tmp1), - dup,t,W) - end - if dy !== nothing - out_ = if W===nothing - f(y, p, t) + if W === nothing + Enzyme.autodiff(S.diffcache.pf, Enzyme.Duplicated(tmp3, tmp4), + Enzyme.Duplicated(y, tmp1), + dup, t) else - f(y, p, t, W) + Enzyme.autodiff(S.diffcache.pf, Enzyme.Duplicated(tmp3, tmp4), + Enzyme.Duplicated(y, tmp1), + dup, t, W) end - dy[:] .= vec(out_) + if dy !== nothing + out_ = if W === nothing + f(y, p, t) + else + f(y, p, t, W) + end + dy[:] .= vec(out_) + end + dλ .= tmp1 + dgrad !== nothing && !(typeof(tmp2) <: DiffEqBase.NullParameters) && + (dgrad[:] .= vec(tmp2)) + dy !== nothing && (dy .= tmp3) end - dλ .= tmp1 - dgrad !== nothing && !(typeof(tmp2) <: DiffEqBase.NullParameters) && (dgrad[:] .= vec(tmp2)) - dy !== nothing && (dy .= tmp3) - end - return + return end function jacNoise!(λ, y, p, t, S::SensitivityFunction; - dgrad=nothing, dλ=nothing, dy=nothing) - _jacNoise!(λ, y, p, t, S, S.sensealg.autojacvec, dgrad, dλ, dy) - return + dgrad = nothing, dλ = nothing, dy = nothing) + _jacNoise!(λ, y, p, t, S, S.sensealg.autojacvec, dgrad, dλ, dy) + return end -function _jacNoise!(λ, y, p, t, S::TS, isnoise::Bool, dgrad, dλ, dy) where TS<:SensitivityFunction - @unpack sensealg, f = S - prob = getprob(S) +function _jacNoise!(λ, y, p, t, S::TS, isnoise::Bool, dgrad, dλ, + dy) where {TS <: SensitivityFunction} + @unpack sensealg, f = S + prob = getprob(S) - if dgrad !== nothing - @unpack pJ, pf, f_cache, paramjac_noise_config = S.diffcache - if DiffEqBase.has_paramjac(f) - # Calculate the parameter Jacobian into pJ - f.paramjac(pJ,y,p,t) - else - pf.t = t - pf.u = y - if inplace_sensitivity(S) - jacobian!(pJ, pf, p, nothing, sensealg, nothing) - #jacobian!(pJ, pf, p, f_cache, sensealg, paramjac_noise_config) - else - temp = jacobian(pf, p, sensealg) - pJ .= temp - end - end + if dgrad !== nothing + @unpack pJ, pf, f_cache, paramjac_noise_config = S.diffcache + if DiffEqBase.has_paramjac(f) + # Calculate the parameter Jacobian into pJ + f.paramjac(pJ, y, p, t) + else + pf.t = t + pf.u = y + if inplace_sensitivity(S) + jacobian!(pJ, pf, p, nothing, sensealg, nothing) + #jacobian!(pJ, pf, p, f_cache, sensealg, paramjac_noise_config) + else + temp = jacobian(pf, p, sensealg) + pJ .= temp + end + end - if StochasticDiffEq.is_diagonal_noise(prob) - pJt = transpose(λ).*transpose(pJ) - dgrad[:] .= vec(pJt) - else - m = size(prob.noise_rate_prototype)[2] - for i in 1:m - tmp = λ'*pJ[(i-1)*m+1:i*m,:] - dgrad[:,i] .= vec(tmp) - end - end - end - - if dλ !== nothing && (isnoisemixing(sensealg) || !StochasticDiffEq.is_diagonal_noise(prob)) - @unpack J, uf, f_cache, jac_noise_config = S.diffcache - if dy!== nothing - if inplace_sensitivity(S) - f(dy, y, p, t) - else - dy .= f(y, p, t) - end + if StochasticDiffEq.is_diagonal_noise(prob) + pJt = transpose(λ) .* transpose(pJ) + dgrad[:] .= vec(pJt) + else + m = size(prob.noise_rate_prototype)[2] + for i in 1:m + tmp = λ' * pJ[((i - 1) * m + 1):(i * m), :] + dgrad[:, i] .= vec(tmp) + end + end end - if DiffEqBase.has_jac(f) - f.jac(J,y,p,t) # Calculate the Jacobian into J - else - if inplace_sensitivity(S) + if dλ !== nothing && + (isnoisemixing(sensealg) || !StochasticDiffEq.is_diagonal_noise(prob)) + @unpack J, uf, f_cache, jac_noise_config = S.diffcache if dy !== nothing - ForwardDiff.jacobian!(J,uf,dy,y) + if inplace_sensitivity(S) + f(dy, y, p, t) + else + dy .= f(y, p, t) + end + end + + if DiffEqBase.has_jac(f) + f.jac(J, y, p, t) # Calculate the Jacobian into J else - if StochasticDiffEq.is_diagonal_noise(prob) - dy = similar(y) - else - dy = similar(prob.noise_rate_prototype) - f(dy, y, p, t) - ForwardDiff.jacobian!(J,uf,dy,y) - end - f(dy, y, p, t) - ForwardDiff.jacobian!(J,uf,dy,y) - end - else - tmp = ForwardDiff.jacobian(uf,y) - J .= tmp - end - # uf.t = t - # uf.p = p - # jacobian!(J, uf, y, nothing, sensealg, nothing) - end + if inplace_sensitivity(S) + if dy !== nothing + ForwardDiff.jacobian!(J, uf, dy, y) + else + if StochasticDiffEq.is_diagonal_noise(prob) + dy = similar(y) + else + dy = similar(prob.noise_rate_prototype) + f(dy, y, p, t) + ForwardDiff.jacobian!(J, uf, dy, y) + end + f(dy, y, p, t) + ForwardDiff.jacobian!(J, uf, dy, y) + end + else + tmp = ForwardDiff.jacobian(uf, y) + J .= tmp + end + # uf.t = t + # uf.p = p + # jacobian!(J, uf, y, nothing, sensealg, nothing) + end - if StochasticDiffEq.is_diagonal_noise(prob) - Jt = transpose(λ).*transpose(J) - dλ[:] .= vec(Jt) - else - for i in 1:m - tmp = λ'*J[(i-1)*m+1:i*m,:] - dλ[:,i] .= vec(tmp) - end + if StochasticDiffEq.is_diagonal_noise(prob) + Jt = transpose(λ) .* transpose(J) + dλ[:] .= vec(Jt) + else + for i in 1:m + tmp = λ' * J[((i - 1) * m + 1):(i * m), :] + dλ[:, i] .= vec(tmp) + end + end end - end - return -end + return +end + +function _jacNoise!(λ, y, p, t, S::TS, isnoise::ReverseDiffVJP, dgrad, dλ, + dy) where {TS <: SensitivityFunction} + @unpack sensealg, f = S + prob = getprob(S) + + for (i, λi) in enumerate(λ) + tapei = S.diffcache.paramjac_noise_config[i] + tu, tp, tt = ReverseDiff.input_hook(tapei) + output = ReverseDiff.output_hook(tapei) + ReverseDiff.unseed!(tu) # clear any "leftover" derivatives from previous calls + ReverseDiff.unseed!(tp) + ReverseDiff.unseed!(tt) + ReverseDiff.value!(tu, y) + ReverseDiff.value!(tp, p) + ReverseDiff.value!(tt, [t]) + ReverseDiff.forward_pass!(tapei) + if StochasticDiffEq.is_diagonal_noise(prob) + ReverseDiff.increment_deriv!(output, λi) + else + ReverseDiff.increment_deriv!(output, λ) + end + ReverseDiff.reverse_pass!(tapei) -function _jacNoise!(λ, y, p, t, S::TS, isnoise::ReverseDiffVJP, dgrad, dλ, dy) where TS<:SensitivityFunction - @unpack sensealg, f = S - prob = getprob(S) + deriv = ReverseDiff.deriv(tp) + dgrad[:, i] .= vec(deriv) + ReverseDiff.pull_value!(output) - for (i, λi) in enumerate(λ) - tapei = S.diffcache.paramjac_noise_config[i] - tu, tp, tt = ReverseDiff.input_hook(tapei) - output = ReverseDiff.output_hook(tapei) - ReverseDiff.unseed!(tu) # clear any "leftover" derivatives from previous calls - ReverseDiff.unseed!(tp) - ReverseDiff.unseed!(tt) - ReverseDiff.value!(tu, y) - ReverseDiff.value!(tp, p) - ReverseDiff.value!(tt, [t]) - ReverseDiff.forward_pass!(tapei) - if StochasticDiffEq.is_diagonal_noise(prob) - ReverseDiff.increment_deriv!(output, λi) - else - ReverseDiff.increment_deriv!(output, λ) + if StochasticDiffEq.is_diagonal_noise(prob) + dλ !== nothing && (dλ[:, i] .= vec(ReverseDiff.deriv(tu))) + dy !== nothing && (dy[i] = ReverseDiff.value(output)) + else + dλ !== nothing && (dλ[:, i] .= vec(ReverseDiff.deriv(tu))) + dy !== nothing && (dy[:, i] .= vec(ReverseDiff.value(output))) + end end - ReverseDiff.reverse_pass!(tapei) + return +end - deriv = ReverseDiff.deriv(tp) - dgrad[:,i] .= vec(deriv) - ReverseDiff.pull_value!(output) +function _jacNoise!(λ, y, p, t, S::TS, isnoise::ZygoteVJP, dgrad, dλ, + dy) where {TS <: SensitivityFunction} + @unpack sensealg, f = S + prob = getprob(S) if StochasticDiffEq.is_diagonal_noise(prob) - dλ !== nothing && (dλ[:,i] .= vec(ReverseDiff.deriv(tu))) - dy !== nothing && (dy[i] = ReverseDiff.value(output)) + if inplace_sensitivity(S) + for (i, λi) in enumerate(λ) + _dy, back = Zygote.pullback(y, p) do u, p + out_ = Zygote.Buffer(similar(u)) + f(out_, u, p, t) + copy(out_[i]) + end + tmp1, tmp2 = back(λi) #issue: tmp2 = zeros(p) + dgrad[:, i] .= vec(tmp2) + dλ !== nothing && (dλ[:, i] .= vec(tmp1)) + dy !== nothing && (dy[i] = _dy) + end + else + for (i, λi) in enumerate(λ) + _dy, back = Zygote.pullback(y, p) do u, p + f(u, p, t)[i] + end + tmp1, tmp2 = back(λi) + dgrad[:, i] .= vec(tmp2) + dλ !== nothing && (dλ[:, i] .= vec(tmp1)) + dy !== nothing && (dy[i] = _dy) + end + end else - dλ !== nothing && (dλ[:,i] .= vec(ReverseDiff.deriv(tu))) - dy !== nothing && (dy[:,i] .= vec(ReverseDiff.value(output))) + if inplace_sensitivity(S) + for (i, λi) in enumerate(λ) + _dy, back = Zygote.pullback(y, p) do u, p + out_ = Zygote.Buffer(similar(prob.noise_rate_prototype)) + f(out_, u, p, t) + copy(out_[:, i]) + end + tmp1, tmp2 = back(λ)#issue with Zygote.Buffer + dgrad[:, i] .= vec(tmp2) + dλ !== nothing && (dλ[:, i] .= vec(tmp1)) + dy !== nothing && (dy[:, i] .= vec(_dy)) + end + else + for (i, λi) in enumerate(λ) + _dy, back = Zygote.pullback(y, p) do u, p + f(u, p, t)[:, i] + end + tmp1, tmp2 = back(λ) + dgrad[:, i] .= vec(tmp2) + if tmp1 === nothing + # if a column of the noise matrix is zero, Zygote returns nothing. + dλ !== nothing && (dλ[:, i] .= false) + else + dλ !== nothing && (dλ[:, i] .= vec(tmp1)) + end + dy !== nothing && (dy[:, i] .= vec(_dy)) + end + end end - end - return + return end - -function _jacNoise!(λ, y, p, t, S::TS, isnoise::ZygoteVJP, dgrad, dλ, dy) where TS<:SensitivityFunction - @unpack sensealg, f = S - prob = getprob(S) - - if StochasticDiffEq.is_diagonal_noise(prob) - if inplace_sensitivity(S) - for (i, λi) in enumerate(λ) - _dy, back = Zygote.pullback(y, p) do u, p - out_ = Zygote.Buffer(similar(u)) - f(out_, u, p, t) - copy(out_[i]) - end - tmp1,tmp2 = back(λi) #issue: tmp2 = zeros(p) - dgrad[:,i] .= vec(tmp2) - dλ !== nothing && (dλ[:,i] .= vec(tmp1)) - dy !== nothing && (dy[i] = _dy) - end +function accumulate_cost!(dλ, y, p, t, S::TS, + dgrad = nothing) where {TS <: SensitivityFunction} + @unpack dg, dg_val, g, g_grad_config = S.diffcache + if dg !== nothing + if !(dg isa Tuple) + dg(dg_val, y, p, t) + dλ .-= vec(dg_val) + else + dg[1](dg_val[1], y, p, t) + dλ .-= vec(dg_val[1]) + if dgrad !== nothing + dg[2](dg_val[2], y, p, t) + dgrad .-= vec(dg_val[2]) + end + end else - for (i, λi) in enumerate(λ) - _dy, back = Zygote.pullback(y, p) do u, p - f(u, p, t)[i] - end - tmp1,tmp2 = back(λi) - dgrad[:,i] .= vec(tmp2) - dλ !== nothing && (dλ[:,i] .= vec(tmp1)) - dy !== nothing && (dy[i] = _dy) - end + g.t = t + gradient!(dg_val, g, y, S.sensealg, g_grad_config) + dλ .-= vec(dg_val) end - else - if inplace_sensitivity(S) - for (i, λi) in enumerate(λ) - _dy, back = Zygote.pullback(y, p) do u, p - out_ = Zygote.Buffer(similar(prob.noise_rate_prototype)) - f(out_, u, p, t) - copy(out_[:,i]) - end - tmp1,tmp2 = back(λ)#issue with Zygote.Buffer - dgrad[:,i] .= vec(tmp2) - dλ !== nothing && (dλ[:,i] .= vec(tmp1)) - dy !== nothing && (dy[:,i] .= vec(_dy)) - end + return nothing +end + +function build_jac_config(alg, uf, u) + if alg_autodiff(alg) + jac_config = ForwardDiff.JacobianConfig(uf, u, u, + ForwardDiff.Chunk{ + determine_chunksize(u, + alg)}()) else - for (i, λi) in enumerate(λ) - _dy, back = Zygote.pullback(y, p) do u, p - f(u, p, t)[:,i] - end - tmp1,tmp2 = back(λ) - dgrad[:,i] .= vec(tmp2) - if tmp1 === nothing - # if a column of the noise matrix is zero, Zygote returns nothing. - dλ !== nothing && (dλ[:,i] .= false) + if diff_type(alg) != Val{:complex} + jac_config = FiniteDiff.JacobianCache(similar(u), similar(u), + similar(u), diff_type(alg)) else - dλ !== nothing && (dλ[:,i] .= vec(tmp1)) + tmp = Complex{eltype(u)}.(u) + du1 = Complex{eltype(u)}.(du1) + jac_config = FiniteDiff.JacobianCache(tmp, du1, nothing, diff_type(alg)) end - dy !== nothing && (dy[:,i] .= vec(_dy)) - end end - end - return + jac_config end - -function accumulate_cost!(dλ, y, p, t, S::TS, dgrad=nothing) where TS<:SensitivityFunction - @unpack dg, dg_val, g, g_grad_config = S.diffcache - if dg !== nothing - if !(dg isa Tuple) - dg(dg_val,y,p,t) - dλ .-= vec(dg_val) +function build_param_jac_config(alg, pf, u, p) + if alg_autodiff(alg) + jac_config = ForwardDiff.JacobianConfig(pf, u, p, + ForwardDiff.Chunk{ + determine_chunksize(p, + alg)}()) else - dg[1](dg_val[1],y,p,t) - dλ .-= vec(dg_val[1]) - if dgrad !== nothing - dg[2](dg_val[2],y,p,t) - dgrad .-= vec(dg_val[2]) - end + if diff_type(alg) != Val{:complex} + jac_config = FiniteDiff.JacobianCache(similar(p), similar(u), + similar(u), diff_type(alg)) + else + tmp = Complex{eltype(p)}.(p) + du1 = Complex{eltype(u)}.(u) + jac_config = FiniteDiff.JacobianCache(tmp, du1, nothing, diff_type(alg)) + end end - else - g.t = t - gradient!(dg_val, g, y, S.sensealg, g_grad_config) - dλ .-= vec(dg_val) - end - return nothing -end - -function build_jac_config(alg,uf,u) - if alg_autodiff(alg) - jac_config = ForwardDiff.JacobianConfig(uf,u,u, - ForwardDiff.Chunk{determine_chunksize(u,alg)}()) - else - if diff_type(alg) != Val{:complex} - jac_config = FiniteDiff.JacobianCache(similar(u),similar(u), - similar(u),diff_type(alg)) + jac_config +end + +function build_grad_config(alg, tf, du1, t) + if alg_autodiff(alg) + grad_config = ForwardDiff.GradientConfig(tf, du1, + ForwardDiff.Chunk{ + determine_chunksize(du1, + alg) + }()) else - tmp = Complex{eltype(u)}.(u) - du1 = Complex{eltype(u)}.(du1) - jac_config = FiniteDiff.JacobianCache(tmp,du1,nothing,diff_type(alg)) + grad_config = FiniteDiff.GradientCache(du1, t, diff_type(alg)) end - end - jac_config -end - -function build_param_jac_config(alg,pf,u,p) - if alg_autodiff(alg) - jac_config = ForwardDiff.JacobianConfig(pf,u,p, - ForwardDiff.Chunk{determine_chunksize(p,alg)}()) - else - if diff_type(alg) != Val{:complex} - jac_config = FiniteDiff.JacobianCache(similar(p),similar(u), - similar(u),diff_type(alg)) + grad_config +end + +function build_deriv_config(alg, tf, du1, t) + if alg_autodiff(alg) + grad_config = ForwardDiff.DerivativeConfig(tf, du1, t) else - tmp = Complex{eltype(p)}.(p) - du1 = Complex{eltype(u)}.(u) - jac_config = FiniteDiff.JacobianCache(tmp,du1,nothing,diff_type(alg)) + grad_config = FiniteDiff.DerivativeCache(du1, t, diff_type(alg)) end - end - jac_config -end - -function build_grad_config(alg,tf,du1,t) - if alg_autodiff(alg) - grad_config = ForwardDiff.GradientConfig(tf,du1, - ForwardDiff.Chunk{determine_chunksize(du1,alg)}()) - else - grad_config = FiniteDiff.GradientCache(du1,t,diff_type(alg)) - end - grad_config -end - -function build_deriv_config(alg,tf,du1,t) - if alg_autodiff(alg) - grad_config = ForwardDiff.DerivativeConfig(tf,du1,t) - else - grad_config = FiniteDiff.DerivativeCache(du1,t,diff_type(alg)) - end - grad_config + grad_config end diff --git a/src/forward_sensitivity.jl b/src/forward_sensitivity.jl index 5940c20bb..bef637f90 100644 --- a/src/forward_sensitivity.jl +++ b/src/forward_sensitivity.jl @@ -7,166 +7,177 @@ with the derivative terms. ODEForwardSensitivityFunction is not intended to be part of the public API. """ -struct ODEForwardSensitivityFunction{iip,F,A,Tt,OJ,J,JP,S,PJ,TW,TWt,UF,PF,JC,PJC,Alg,fc,JM,pJM,MM,CV} <: DiffEqBase.AbstractODEFunction{iip} - f::F - analytic::A - tgrad::Tt - original_jac::OJ - jac::J - jac_prototype::JP - sparsity::S - paramjac::PJ - Wfact::TW - Wfact_t::TWt - uf::UF - pf::PF - J::JM - pJ::pJM - jac_config::JC - paramjac_config::PJC - alg::Alg - numparams::Int - numindvar::Int - f_cache::fc - mass_matrix::MM - isautojacvec::Bool - isautojacmat::Bool - colorvec::CV +struct ODEForwardSensitivityFunction{iip, F, A, Tt, OJ, J, JP, S, PJ, TW, TWt, UF, PF, JC, + PJC, Alg, fc, JM, pJM, MM, CV} <: + DiffEqBase.AbstractODEFunction{iip} + f::F + analytic::A + tgrad::Tt + original_jac::OJ + jac::J + jac_prototype::JP + sparsity::S + paramjac::PJ + Wfact::TW + Wfact_t::TWt + uf::UF + pf::PF + J::JM + pJ::pJM + jac_config::JC + paramjac_config::PJC + alg::Alg + numparams::Int + numindvar::Int + f_cache::fc + mass_matrix::MM + isautojacvec::Bool + isautojacmat::Bool + colorvec::CV end has_original_jac(S) = isdefined(S, :original_jac) && S.jac !== nothing -struct NILSSForwardSensitivityFunction{iip,sensefunType,senseType,MM} <: DiffEqBase.AbstractODEFunction{iip} - S::sensefunType - sensealg::senseType - nus::Int - mass_matrix::MM -end - -function ODEForwardSensitivityFunction(f,analytic,tgrad,original_jac,jac,jac_prototype,sparsity,paramjac,Wfact,Wfact_t,uf,pf,u0, - jac_config,paramjac_config,alg,p,f_cache,mm, - isautojacvec,isautojacmat,colorvec,nus) - numparams = length(p) - numindvar = length(u0) - J = isautojacvec ? nothing : Matrix{eltype(u0)}(undef,numindvar,numindvar) - pJ = Matrix{eltype(u0)}(undef,numindvar,numparams) # number of funcs size - - sensefun = ODEForwardSensitivityFunction{isinplace(f),typeof(f),typeof(analytic), - typeof(tgrad),typeof(original_jac), - typeof(jac),typeof(jac_prototype),typeof(sparsity), - typeof(paramjac), - typeof(Wfact),typeof(Wfact_t),typeof(uf), - typeof(pf),typeof(jac_config), - typeof(paramjac_config),typeof(alg), - typeof(f_cache), - typeof(J),typeof(pJ),typeof(mm),typeof(f.colorvec)}( - f,analytic,tgrad,original_jac,jac,jac_prototype, - sparsity,paramjac,Wfact,Wfact_t,uf,pf,J,pJ, - jac_config,paramjac_config,alg, - numparams,numindvar,f_cache,mm,isautojacvec,isautojacmat,colorvec, - ) - if nus!==nothing - sensefun = NILSSForwardSensitivityFunction{isinplace(f), typeof(sensefun), - typeof(alg),typeof(mm)}(sensefun,alg,nus,mm) - end - - return sensefun -end - -function (S::ODEForwardSensitivityFunction)(du,u,p,t) - y = @view u[1:S.numindvar] # These are the independent variables - dy = @view du[1:S.numindvar] - S.f(dy,y,p,t) # Make the first part be the ODE - - # Now do sensitivities - # Compute the Jacobian - - if !S.isautojacvec && !S.isautojacmat - if has_original_jac(S) - S.original_jac(S.J,y,p,t) # Calculate the Jacobian into J +struct NILSSForwardSensitivityFunction{iip, sensefunType, senseType, MM} <: + DiffEqBase.AbstractODEFunction{iip} + S::sensefunType + sensealg::senseType + nus::Int + mass_matrix::MM +end + +function ODEForwardSensitivityFunction(f, analytic, tgrad, original_jac, jac, jac_prototype, + sparsity, paramjac, Wfact, Wfact_t, uf, pf, u0, + jac_config, paramjac_config, alg, p, f_cache, mm, + isautojacvec, isautojacmat, colorvec, nus) + numparams = length(p) + numindvar = length(u0) + J = isautojacvec ? nothing : Matrix{eltype(u0)}(undef, numindvar, numindvar) + pJ = Matrix{eltype(u0)}(undef, numindvar, numparams) # number of funcs size + + sensefun = ODEForwardSensitivityFunction{isinplace(f), typeof(f), typeof(analytic), + typeof(tgrad), typeof(original_jac), + typeof(jac), typeof(jac_prototype), + typeof(sparsity), + typeof(paramjac), + typeof(Wfact), typeof(Wfact_t), typeof(uf), + typeof(pf), typeof(jac_config), + typeof(paramjac_config), typeof(alg), + typeof(f_cache), + typeof(J), typeof(pJ), typeof(mm), + typeof(f.colorvec)}(f, analytic, tgrad, + original_jac, jac, + jac_prototype, + sparsity, paramjac, Wfact, + Wfact_t, uf, pf, J, pJ, + jac_config, + paramjac_config, alg, + numparams, numindvar, + f_cache, mm, isautojacvec, + isautojacmat, colorvec) + if nus !== nothing + sensefun = NILSSForwardSensitivityFunction{isinplace(f), typeof(sensefun), + typeof(alg), typeof(mm)}(sensefun, alg, + nus, mm) + end + + return sensefun +end + +function (S::ODEForwardSensitivityFunction)(du, u, p, t) + y = @view u[1:(S.numindvar)] # These are the independent variables + dy = @view du[1:(S.numindvar)] + S.f(dy, y, p, t) # Make the first part be the ODE + + # Now do sensitivities + # Compute the Jacobian + + if !S.isautojacvec && !S.isautojacmat + if has_original_jac(S) + S.original_jac(S.J, y, p, t) # Calculate the Jacobian into J + else + S.uf.t = t + jacobian!(S.J, S.uf, y, S.f_cache, S.alg, S.jac_config) + end + end + + if DiffEqBase.has_paramjac(S.f) + S.paramjac(S.pJ, y, p, t) # Calculate the parameter Jacobian into pJ else - S.uf.t = t - jacobian!(S.J, S.uf, y, S.f_cache, S.alg, S.jac_config) + S.pf.t = t + copyto!(S.pf.u, y) + jacobian!(S.pJ, S.pf, p, S.f_cache, S.alg, S.paramjac_config) end - end - - - if DiffEqBase.has_paramjac(S.f) - S.paramjac(S.pJ,y,p,t) # Calculate the parameter Jacobian into pJ - else - S.pf.t = t - copyto!(S.pf.u,y) - jacobian!(S.pJ, S.pf, p, S.f_cache, S.alg, S.paramjac_config) - end - - # Compute the parameter derivatives - if !S.isautojacvec && !S.isautojacmat - dp = @view du[reshape(S.numindvar+1:(length(p)+1)*S.numindvar,S.numindvar,length(p))] - Sj = @view u[reshape(S.numindvar+1:(length(p)+1)*S.numindvar,S.numindvar,length(p))] - mul!(dp,S.J,Sj) - DiffEqBase.@.. dp += S.pJ - elseif S.isautojacmat - S.uf.t = t - Sj = @view u[reshape(S.numindvar+1:end,S.numindvar,S.numparams)] - dp = @view du[reshape(S.numindvar+1:end,S.numindvar,S.numparams)] - jacobianmat!(dp, S.uf, y, Sj, S.alg, S.jac_config) - DiffEqBase.@.. dp += S.pJ - else - S.uf.t = t - for i in eachindex(p) - Sj = @view u[i*S.numindvar+1:(i+1)*S.numindvar] - dp = @view du[i*S.numindvar+1:(i+1)*S.numindvar] - jacobianvec!(dp, S.uf, y, Sj, S.alg, S.jac_config) - dp .+= @view S.pJ[:,i] + + # Compute the parameter derivatives + if !S.isautojacvec && !S.isautojacmat + dp = @view du[reshape((S.numindvar + 1):((length(p) + 1) * S.numindvar), + S.numindvar, length(p))] + Sj = @view u[reshape((S.numindvar + 1):((length(p) + 1) * S.numindvar), S.numindvar, + length(p))] + mul!(dp, S.J, Sj) + DiffEqBase.@.. dp += S.pJ + elseif S.isautojacmat + S.uf.t = t + Sj = @view u[reshape((S.numindvar + 1):end, S.numindvar, S.numparams)] + dp = @view du[reshape((S.numindvar + 1):end, S.numindvar, S.numparams)] + jacobianmat!(dp, S.uf, y, Sj, S.alg, S.jac_config) + DiffEqBase.@.. dp += S.pJ + else + S.uf.t = t + for i in eachindex(p) + Sj = @view u[(i * S.numindvar + 1):((i + 1) * S.numindvar)] + dp = @view du[(i * S.numindvar + 1):((i + 1) * S.numindvar)] + jacobianvec!(dp, S.uf, y, Sj, S.alg, S.jac_config) + dp .+= @view S.pJ[:, i] + end end - end - return nothing + return nothing end -@deprecate ODELocalSensitivityProblem(args...;kwargs...) ODEForwardSensitivityProblem(args...;kwargs...) +@deprecate ODELocalSensitivityProblem(args...; kwargs...) ODEForwardSensitivityProblem(args...; + kwargs...) -struct ODEForwardSensitivityProblem{iip,A} - sensealg::A +struct ODEForwardSensitivityProblem{iip, A} + sensealg::A end -function ODEForwardSensitivityProblem(f::F,args...;kwargs...) where F - ODEForwardSensitivityProblem(ODEFunction(f),args...;kwargs...) +function ODEForwardSensitivityProblem(f::F, args...; kwargs...) where {F} + ODEForwardSensitivityProblem(ODEFunction(f), args...; kwargs...) end -function ODEForwardSensitivityProblem(prob::ODEProblem,alg;kwargs...) - ODEForwardSensitivityProblem(prob.f,prob.u0,prob.tspan,prob.p,alg;kwargs...) +function ODEForwardSensitivityProblem(prob::ODEProblem, alg; kwargs...) + ODEForwardSensitivityProblem(prob.f, prob.u0, prob.tspan, prob.p, alg; kwargs...) end -const FORWARD_SENSITIVITY_PARAMETER_COMPATABILITY_MESSAGE = -""" -ODEForwardSensitivityProblem requires being able to solve -a differential equation defined by the parameter struct `p`. Thus while -DifferentialEquations.jl can support any parameter struct type, usage -with ODEForwardSensitivityProblem requires that `p` could be a valid -type for being the initial condition `u0` of an array. This means that -many simple types, such as `Tuple`s and `NamedTuple`s, will work as -parameters in normal contexts but will fail during ODEForwardSensitivityProblem -construction. To work around this issue for complicated cases like nested structs, -look into defining `p` using `AbstractArray` libraries such as RecursiveArrayTools.jl -or ComponentArrays.jl. -""" +const FORWARD_SENSITIVITY_PARAMETER_COMPATABILITY_MESSAGE = """ + ODEForwardSensitivityProblem requires being able to solve + a differential equation defined by the parameter struct `p`. Thus while + DifferentialEquations.jl can support any parameter struct type, usage + with ODEForwardSensitivityProblem requires that `p` could be a valid + type for being the initial condition `u0` of an array. This means that + many simple types, such as `Tuple`s and `NamedTuple`s, will work as + parameters in normal contexts but will fail during ODEForwardSensitivityProblem + construction. To work around this issue for complicated cases like nested structs, + look into defining `p` using `AbstractArray` libraries such as RecursiveArrayTools.jl + or ComponentArrays.jl. + """ struct ForwardSensitivityParameterCompatibilityError <: Exception end function Base.showerror(io::IO, e::ForwardSensitivityParameterCompatibilityError) - print(io, FORWARD_SENSITIVITY_PARAMETER_COMPATABILITY_MESSAGE) + print(io, FORWARD_SENSITIVITY_PARAMETER_COMPATABILITY_MESSAGE) end -const FORWARD_SENSITIVITY_OUT_OF_PLACE_MESSAGE = -""" -ODEForwardSensitivityProblem is not compatible with out of place ODE definitions, -i.e. `du=f(u,p,t)` definitions. It requires an in-place mutating function -`f(du,u,p,t)`. For more information on in-place vs out-of-place ODE definitions, -see the ODEProblem or ODEFunction documentation. -""" +const FORWARD_SENSITIVITY_OUT_OF_PLACE_MESSAGE = """ + ODEForwardSensitivityProblem is not compatible with out of place ODE definitions, + i.e. `du=f(u,p,t)` definitions. It requires an in-place mutating function + `f(du,u,p,t)`. For more information on in-place vs out-of-place ODE definitions, + see the ODEProblem or ODEFunction documentation. + """ struct ForwardSensitivityOutOfPlaceError <: Exception end function Base.showerror(io::IO, e::ForwardSensitivityOutOfPlaceError) - print(io, FORWARD_SENSITIVITY_OUT_OF_PLACE_MESSAGE) + print(io, FORWARD_SENSITIVITY_OUT_OF_PLACE_MESSAGE) end @doc doc""" @@ -326,137 +337,150 @@ at time `sol.t[i]`. Note that all of the functionality available to ODE solution is available in this case, including interpolations and plot recipes (the recipes will plot the expanded system). """ -function ODEForwardSensitivityProblem(f::F,u0, - tspan,p=nothing, - alg::ForwardSensitivity = ForwardSensitivity(); - nus=nothing, # determine if Nilss is used - w0=nothing, - v0=nothing, - kwargs...) where F<:DiffEqBase.AbstractODEFunction - isinplace = SciMLBase.isinplace(f) - # if there is an analytical Jacobian provided, we are not going to do automatic `jac*vec` - isautojacmat = get_jacmat(alg) - isautojacvec = get_jacvec(alg) - p === nothing && error("You must have parameters to use parameter sensitivity calculations!") - - if !(typeof(p) <: Union{Nothing,SciMLBase.NullParameters,AbstractArray}) || (p isa AbstractArray && !Base.isconcretetype(eltype(p))) - throw(ForwardSensitivityParameterCompatibilityError()) - end - - uf = DiffEqBase.UJacobianWrapper(f,tspan[1],p) - pf = DiffEqBase.ParamJacobianWrapper(f,tspan[1],copy(u0)) - if isautojacmat - if alg_autodiff(alg) - jac_config_seed = ForwardDiff.Dual{typeof(uf)}.(u0,[ntuple(x -> zero(eltype(u0)), length(p)) for i in eachindex(u0)]) - jac_config_buffer = similar(jac_config_seed) - jac_config = jac_config_seed, jac_config_buffer +function ODEForwardSensitivityProblem(f::F, u0, + tspan, p = nothing, + alg::ForwardSensitivity = ForwardSensitivity(); + nus = nothing, # determine if Nilss is used + w0 = nothing, + v0 = nothing, + kwargs...) where {F <: DiffEqBase.AbstractODEFunction} + isinplace = SciMLBase.isinplace(f) + # if there is an analytical Jacobian provided, we are not going to do automatic `jac*vec` + isautojacmat = get_jacmat(alg) + isautojacvec = get_jacvec(alg) + p === nothing && + error("You must have parameters to use parameter sensitivity calculations!") + + if !(typeof(p) <: Union{Nothing, SciMLBase.NullParameters, AbstractArray}) || + (p isa AbstractArray && !Base.isconcretetype(eltype(p))) + throw(ForwardSensitivityParameterCompatibilityError()) + end + + uf = DiffEqBase.UJacobianWrapper(f, tspan[1], p) + pf = DiffEqBase.ParamJacobianWrapper(f, tspan[1], copy(u0)) + if isautojacmat + if alg_autodiff(alg) + jac_config_seed = ForwardDiff.Dual{typeof(uf) + }.(u0, + [ntuple(x -> zero(eltype(u0)), length(p)) + for i in eachindex(u0)]) + jac_config_buffer = similar(jac_config_seed) + jac_config = jac_config_seed, jac_config_buffer + else + error("Jacobian matrix products only work with automatic differentiation.") + end + elseif isautojacvec + if alg_autodiff(alg) + # if we are using automatic `jac*vec`, then we need to use a `jac_config` + # that is a tuple in the form of `(seed, buffer)` + jac_config_seed = ForwardDiff.Dual{typeof(jacobianvec!)}.(u0, u0) + jac_config_buffer = similar(jac_config_seed) + jac_config = jac_config_seed, jac_config_buffer + else + jac_config = (similar(u0), similar(u0)) + end + elseif DiffEqBase.has_jac(f) + jac_config = nothing else - error("Jacobian matrix products only work with automatic differentiation.") + jac_config = build_jac_config(alg, uf, u0) end - elseif isautojacvec - if alg_autodiff(alg) - # if we are using automatic `jac*vec`, then we need to use a `jac_config` - # that is a tuple in the form of `(seed, buffer)` - jac_config_seed = ForwardDiff.Dual{typeof(jacobianvec!)}.(u0,u0) - jac_config_buffer = similar(jac_config_seed) - jac_config = jac_config_seed, jac_config_buffer + + if DiffEqBase.has_paramjac(f) + paramjac_config = nothing else - jac_config = (similar(u0),similar(u0)) + paramjac_config = build_param_jac_config(alg, pf, u0, p) end - elseif DiffEqBase.has_jac(f) - jac_config = nothing - else - jac_config = build_jac_config(alg,uf,u0) - end - - if DiffEqBase.has_paramjac(f) - paramjac_config = nothing - else - paramjac_config = build_param_jac_config(alg,pf,u0,p) - end - - # TODO: make it better - if f.mass_matrix isa UniformScaling - mm = f.mass_matrix - else - nn = size(f.mass_matrix, 1) - mm = zeros(eltype(f.mass_matrix), (length(p)+1)*nn, (length(p)+1)*nn) - mm[1:nn, 1:nn] = f.mass_matrix - for i = 1:length(p) - mm[i*nn+1:(i+1)nn, i*nn+1:(i+1)nn] = f.mass_matrix + + # TODO: make it better + if f.mass_matrix isa UniformScaling + mm = f.mass_matrix + else + nn = size(f.mass_matrix, 1) + mm = zeros(eltype(f.mass_matrix), (length(p) + 1) * nn, (length(p) + 1) * nn) + mm[1:nn, 1:nn] = f.mass_matrix + for i in 1:length(p) + mm[(i * nn + 1):((i + 1)nn), (i * nn + 1):((i + 1)nn)] = f.mass_matrix + end + end + + # TODO: Use user tgrad. iW can be safely ignored here. + sense = ODEForwardSensitivityFunction(f, f.analytic, nothing, f.jac, nothing, + nothing, nothing, f.paramjac, + nothing, nothing, + uf, pf, u0, jac_config, + paramjac_config, alg, + p, similar(u0), mm, + isautojacvec, isautojacmat, f.colorvec, nus) + + if !SciMLBase.isinplace(sense) + throw(ForwardSensitivityOutOfPlaceError()) end - end - - # TODO: Use user tgrad. iW can be safely ignored here. - sense = ODEForwardSensitivityFunction(f,f.analytic,nothing,f.jac,nothing, - nothing,nothing,f.paramjac, - nothing,nothing, - uf,pf,u0,jac_config, - paramjac_config,alg, - p,similar(u0),mm, - isautojacvec,isautojacmat,f.colorvec,nus) - - if !SciMLBase.isinplace(sense) - throw(ForwardSensitivityOutOfPlaceError()) - end - - if nus===nothing - sense_u0 = [u0;zeros(eltype(u0),sense.numindvar*sense.numparams)] - else - if w0===nothing && v0===nothing - sense_u0 = [u0;zeros(eltype(u0),(nus+1)*sense.S.numindvar*sense.S.numparams)] + + if nus === nothing + sense_u0 = [u0; zeros(eltype(u0), sense.numindvar * sense.numparams)] else - sense_u0 = [u0;w0;v0] + if w0 === nothing && v0 === nothing + sense_u0 = [u0; + zeros(eltype(u0), + (nus + 1) * sense.S.numindvar * sense.S.numparams)] + else + sense_u0 = [u0; w0; v0] + end end - end - ODEProblem(sense,sense_u0,tspan,p, - ODEForwardSensitivityProblem{DiffEqBase.isinplace(f), - typeof(alg)}(alg); - kwargs...) + ODEProblem(sense, sense_u0, tspan, p, + ODEForwardSensitivityProblem{DiffEqBase.isinplace(f), + typeof(alg)}(alg); + kwargs...) end -function seed_duals(x::AbstractArray{V},f, - ::ForwardDiff.Chunk{N} = ForwardDiff.Chunk(x,typemax(Int64)), - ) where {V,T,N} - seeds = ForwardDiff.construct_seeds(ForwardDiff.Partials{N,V}) - duals = ForwardDiff.Dual{typeof(ForwardDiff.Tag(f,eltype(vec(x))))}.(vec(x),seeds) +function seed_duals(x::AbstractArray{V}, f, + ::ForwardDiff.Chunk{N} = ForwardDiff.Chunk(x, typemax(Int64))) where {V, + T, + N} + seeds = ForwardDiff.construct_seeds(ForwardDiff.Partials{N, V}) + duals = ForwardDiff.Dual{typeof(ForwardDiff.Tag(f, eltype(vec(x))))}.(vec(x), seeds) end has_continuous_callback(cb::DiscreteCallback) = false has_continuous_callback(cb::ContinuousCallback) = true has_continuous_callback(cb::CallbackSet) = !isempty(cb.continuous_callbacks) -function ODEForwardSensitivityProblem(f::DiffEqBase.AbstractODEFunction,u0, - tspan,p,alg::ForwardDiffSensitivity; - du0=zeros(eltype(u0),length(u0),length(p)), # perturbations of initial condition - dp=I(length(p)), # perturbations of parameters +function ODEForwardSensitivityProblem(f::DiffEqBase.AbstractODEFunction, u0, + tspan, p, alg::ForwardDiffSensitivity; + du0 = zeros(eltype(u0), length(u0), length(p)), # perturbations of initial condition + dp = I(length(p)), # perturbations of parameters kwargs...) - num_sen_par = size(du0,2) - if num_sen_par != size(dp,2) - error("Same number of perturbations of initial conditions and parameters required") - end - if size(du0,1) != length(u0) - error("Perturbations for all initial conditions required") - end - if size(dp,1) != length(p) - error("Perturbations for all parameters required") - end - - pdual = ForwardDiff.Dual{typeof(ForwardDiff.Tag(f,eltype(vec(p))))}.(p, [ntuple(j -> dp[i,j], num_sen_par) for i in eachindex(p)]) - u0dual = ForwardDiff.Dual{typeof(ForwardDiff.Tag(f,eltype(vec(u0))))}.(u0, [ntuple(j -> du0[i,j], num_sen_par) for i in eachindex(u0)]) - - if (convert_tspan(alg) === nothing && - haskey(kwargs,:callback) && has_continuous_callback(kwargs.callback) - ) || (convert_tspan(alg) !== nothing && convert_tspan(alg)) - tspandual = convert.(eltype(pdual),tspan) - else - tspandual = tspan - end - - prob_dual = ODEProblem(f,u0dual,tspan,pdual, - ODEForwardSensitivityProblem{DiffEqBase.isinplace(f), - typeof(alg)}(alg); - kwargs...) + num_sen_par = size(du0, 2) + if num_sen_par != size(dp, 2) + error("Same number of perturbations of initial conditions and parameters required") + end + if size(du0, 1) != length(u0) + error("Perturbations for all initial conditions required") + end + if size(dp, 1) != length(p) + error("Perturbations for all parameters required") + end + + pdual = ForwardDiff.Dual{typeof(ForwardDiff.Tag(f, eltype(vec(p)))) + }.(p, + [ntuple(j -> dp[i, j], num_sen_par) for i in eachindex(p)]) + u0dual = ForwardDiff.Dual{typeof(ForwardDiff.Tag(f, eltype(vec(u0)))) + }.(u0, + [ntuple(j -> du0[i, j], num_sen_par) + for i in eachindex(u0)]) + + if (convert_tspan(alg) === nothing && + haskey(kwargs, :callback) && has_continuous_callback(kwargs.callback)) || + (convert_tspan(alg) !== nothing && convert_tspan(alg)) + tspandual = convert.(eltype(pdual), tspan) + else + tspandual = tspan + end + + prob_dual = ODEProblem(f, u0dual, tspan, pdual, + ODEForwardSensitivityProblem{DiffEqBase.isinplace(f), + typeof(alg)}(alg); + kwargs...) end """ @@ -471,107 +495,128 @@ extract_local_sensitivities(sol, i::Integer, asmatrix::Val=Val(false)) # Decompo extract_local_sensitivities(sol, t::Union{Number,AbstractVector}, asmatrix::Val=Val(false)) # Decompose sol(t) ``` """ -extract_local_sensitivities(sol, asmatrix::Val=Val(false)) = extract_local_sensitivities(sol,sol.prob.problem_type.sensealg, asmatrix) -extract_local_sensitivities(sol, asmatrix::Bool) = extract_local_sensitivities(sol, Val{asmatrix}()) -extract_local_sensitivities(sol, i::Integer, asmatrix::Val=Val(false)) = _extract(sol, sol.prob.problem_type.sensealg, sol[i], asmatrix) -extract_local_sensitivities(sol, i::Integer, asmatrix::Bool) = extract_local_sensitivities(sol, i, Val{asmatrix}()) -extract_local_sensitivities(sol, t::Union{Number,AbstractVector}, asmatrix::Val=Val(false)) = _extract(sol, sol.prob.problem_type.sensealg, sol(t), asmatrix) -extract_local_sensitivities(sol, t, asmatrix::Bool) = extract_local_sensitivities(sol, t, Val{asmatrix}()) -extract_local_sensitivities(tmp, sol, t::Union{Number,AbstractVector}, asmatrix::Val=Val(false)) = _extract(sol, sol.prob.problem_type.sensealg, sol(tmp, t), asmatrix) -extract_local_sensitivities(tmp, sol, t, asmatrix::Bool) = extract_local_sensitivities(tmp, sol, t, Val{asmatrix}()) +function extract_local_sensitivities(sol, asmatrix::Val = Val(false)) + extract_local_sensitivities(sol, sol.prob.problem_type.sensealg, asmatrix) +end +function extract_local_sensitivities(sol, asmatrix::Bool) + extract_local_sensitivities(sol, Val{asmatrix}()) +end +function extract_local_sensitivities(sol, i::Integer, asmatrix::Val = Val(false)) + _extract(sol, sol.prob.problem_type.sensealg, sol[i], asmatrix) +end +function extract_local_sensitivities(sol, i::Integer, asmatrix::Bool) + extract_local_sensitivities(sol, i, Val{asmatrix}()) +end +function extract_local_sensitivities(sol, t::Union{Number, AbstractVector}, + asmatrix::Val = Val(false)) + _extract(sol, sol.prob.problem_type.sensealg, sol(t), asmatrix) +end +function extract_local_sensitivities(sol, t, asmatrix::Bool) + extract_local_sensitivities(sol, t, Val{asmatrix}()) +end +function extract_local_sensitivities(tmp, sol, t::Union{Number, AbstractVector}, + asmatrix::Val = Val(false)) + _extract(sol, sol.prob.problem_type.sensealg, sol(tmp, t), asmatrix) +end +function extract_local_sensitivities(tmp, sol, t, asmatrix::Bool) + extract_local_sensitivities(tmp, sol, t, Val{asmatrix}()) +end # Get ODE u vector and sensitivity values from all time points -function extract_local_sensitivities(sol,::ForwardSensitivity, ::Val{false}) - ni = sol.prob.f.numindvar - u = sol[1:ni, :] - du = [sol[ni*j+1:ni*(j+1),:] for j in 1:sol.prob.f.numparams] - return u, du -end - -function extract_local_sensitivities(sol,::ForwardDiffSensitivity, ::Val{false}) - u = ForwardDiff.value.(sol) - du_full = ForwardDiff.partials.(sol) - firststate = first(du_full) - firstparam = first(firststate) - Js = map(1:length(firstparam)) do j - map(CartesianIndices(du_full)) do II - du_full[II][j] +function extract_local_sensitivities(sol, ::ForwardSensitivity, ::Val{false}) + ni = sol.prob.f.numindvar + u = sol[1:ni, :] + du = [sol[(ni * j + 1):(ni * (j + 1)), :] for j in 1:(sol.prob.f.numparams)] + return u, du +end + +function extract_local_sensitivities(sol, ::ForwardDiffSensitivity, ::Val{false}) + u = ForwardDiff.value.(sol) + du_full = ForwardDiff.partials.(sol) + firststate = first(du_full) + firstparam = first(firststate) + Js = map(1:length(firstparam)) do j + map(CartesianIndices(du_full)) do II + du_full[II][j] + end end - end - return u, Js -end - -function extract_local_sensitivities(sol,::ForwardSensitivity, ::Val{true}) - prob = sol.prob - ni = prob.f.numindvar - pn = prob.f.numparams - jsize = (ni, pn) - sol[1:ni, :], map(sol.u) do u - collect(reshape((@view u[ni+1:end]), jsize)) - end -end - -function extract_local_sensitivities(sol,::ForwardDiffSensitivity, ::Val{true}) - retu = ForwardDiff.value.(sol) - jsize = length(sol.u[1]), ForwardDiff.npartials(sol.u[1][1]) - du = map(sol.u) do u - du_i = similar(retu, jsize) - for i in eachindex(u) - du_i[i, :] = ForwardDiff.partials(u[i]) + return u, Js +end + +function extract_local_sensitivities(sol, ::ForwardSensitivity, ::Val{true}) + prob = sol.prob + ni = prob.f.numindvar + pn = prob.f.numparams + jsize = (ni, pn) + sol[1:ni, :], map(sol.u) do u + collect(reshape((@view u[(ni + 1):end]), jsize)) end - du_i - end - retu, du +end + +function extract_local_sensitivities(sol, ::ForwardDiffSensitivity, ::Val{true}) + retu = ForwardDiff.value.(sol) + jsize = length(sol.u[1]), ForwardDiff.npartials(sol.u[1][1]) + du = map(sol.u) do u + du_i = similar(retu, jsize) + for i in eachindex(u) + du_i[i, :] = ForwardDiff.partials(u[i]) + end + du_i + end + retu, du end # Get ODE u vector and sensitivity values from sensitivity problem u vector -function _extract(sol, sensealg::ForwardSensitivity, su::AbstractVector, asmatrix::Val = Val(false)) - u = view(su, 1:sol.prob.f.numindvar) - du = _extract_du(sol, sensealg, su, asmatrix) - return u, du +function _extract(sol, sensealg::ForwardSensitivity, su::AbstractVector, + asmatrix::Val = Val(false)) + u = view(su, 1:(sol.prob.f.numindvar)) + du = _extract_du(sol, sensealg, su, asmatrix) + return u, du end -function _extract(sol, sensealg::ForwardDiffSensitivity, su::AbstractVector, asmatrix::Val = Val(false)) - u = ForwardDiff.value.(su) - du = _extract_du(sol, sensealg, su, asmatrix) - return u, du +function _extract(sol, sensealg::ForwardDiffSensitivity, su::AbstractVector, + asmatrix::Val = Val(false)) + u = ForwardDiff.value.(su) + du = _extract_du(sol, sensealg, su, asmatrix) + return u, du end # Get sensitivity values from sensitivity problem u vector (nested form) function _extract_du(sol, ::ForwardSensitivity, su::Vector, ::Val{false}) - ni = sol.prob.f.numindvar - return [view(su, ni*j+1:ni*(j+1)) for j in 1:sol.prob.f.numparams] + ni = sol.prob.f.numindvar + return [view(su, (ni * j + 1):(ni * (j + 1))) for j in 1:(sol.prob.f.numparams)] end function _extract_du(sol, ::ForwardDiffSensitivity, su::Vector, ::Val{false}) - du_full = ForwardDiff.partials.(su) - return [[du_full[i][j] for i in 1:size(du_full,1)] for j in 1:length(du_full[1])] + du_full = ForwardDiff.partials.(su) + return [[du_full[i][j] for i in 1:size(du_full, 1)] for j in 1:length(du_full[1])] end # Get sensitivity values from sensitivity problem u vector (matrix form) function _extract_du(sol, ::ForwardSensitivity, su::Vector, ::Val{true}) - ni = sol.prob.f.numindvar - np = sol.prob.f.numparams - return view(reshape(su, ni, np+1), :, 2:np+1) + ni = sol.prob.f.numindvar + np = sol.prob.f.numparams + return view(reshape(su, ni, np + 1), :, 2:(np + 1)) end function _extract_du(sol, ::ForwardDiffSensitivity, su::Vector, ::Val{true}) - du_full = ForwardDiff.partials.(su) - return [du_full[i][j] for i in 1:size(du_full,1), j in 1:length(du_full[1])] + du_full = ForwardDiff.partials.(su) + return [du_full[i][j] for i in 1:size(du_full, 1), j in 1:length(du_full[1])] end - ### Bonus Pieces -function SciMLBase.remake(prob::ODEProblem{uType,tType,isinplace,P,F,K,<:ODEForwardSensitivityProblem}; - f=nothing,tspan=nothing,u0=nothing,p=nothing,kwargs...) where - {uType,tType,isinplace,P,F,K} - _p = p === nothing ? prob.p : p - _f = f === nothing ? prob.f.f : f - _u0 = u0 === nothing ? prob.u0[1:prob.f.numindvar] : u0[1:prob.f.numindvar] +function SciMLBase.remake(prob::ODEProblem{uType, tType, isinplace, P, F, K, + <:ODEForwardSensitivityProblem}; + f = nothing, tspan = nothing, u0 = nothing, p = nothing, + kwargs...) where + {uType, tType, isinplace, P, F, K} + _p = p === nothing ? prob.p : p + _f = f === nothing ? prob.f.f : f + _u0 = u0 === nothing ? prob.u0[1:(prob.f.numindvar)] : u0[1:(prob.f.numindvar)] _tspan = tspan === nothing ? prob.tspan : tspan - ODEForwardSensitivityProblem(_f,_u0, - _tspan,_p,prob.problem_type.sensealg; - prob.kwargs...,kwargs...) + ODEForwardSensitivityProblem(_f, _u0, + _tspan, _p, prob.problem_type.sensealg; + prob.kwargs..., kwargs...) end SciMLBase.ODEFunction(f::ODEForwardSensitivityFunction; kwargs...) = f diff --git a/src/hasbranching.jl b/src/hasbranching.jl index 50bf35578..073b70005 100644 --- a/src/hasbranching.jl +++ b/src/hasbranching.jl @@ -12,7 +12,10 @@ end for (mod, f, n) in DiffRules.diffrules() isdefined(@__MODULE__, mod) || continue - @eval Cassette.overdub(::HasBranchingCtx, f::Core.Typeof($mod.$f), x::Vararg{Any, $n}) = f(x...) + @eval function Cassette.overdub(::HasBranchingCtx, f::Core.Typeof($mod.$f), + x::Vararg{Any, $n}) + f(x...) + end end function _pass(::Type{<:HasBranchingCtx}, reflection::Cassette.Reflection) @@ -20,27 +23,33 @@ function _pass(::Type{<:HasBranchingCtx}, reflection::Cassette.Reflection) if any(x -> isa(x, GotoIfNot), ir.code) printbranch && println("GotoIfNot detected in $(reflection.method)\nir = $ir\n") - Cassette.insert_statements!( - ir.code, ir.codelocs, - (stmt, i) -> i == 1 ? 3 : nothing, - (stmt, i) -> Any[ - Expr(:call, Expr(:nooverdub, GlobalRef(Base, :getfield)), Expr(:contextslot), QuoteNode(:metadata)), - Expr(:call, Expr(:nooverdub, GlobalRef(Base, :setindex!)), SSAValue(1), true, QuoteNode(:has_branching)), - stmt, - ], - ) - Cassette.insert_statements!( - ir.code, ir.codelocs, - (stmt, i) -> i > 2 && isa(stmt, Expr) ? 1 : nothing, - (stmt, i) -> begin - callstmt = Meta.isexpr(stmt, :(=)) ? stmt.args[2] : stmt - Meta.isexpr(stmt, :call) || Meta.isexpr(stmt, :invoke) || return Any[stmt] - callstmt = Expr(callstmt.head, Expr(:nooverdub, callstmt.args[1]), callstmt.args[2:end]...) - return Any[ - Meta.isexpr(stmt, :(=)) ? Expr(:(=), stmt.args[1], callstmt) : callstmt, - ] - end, - ) + Cassette.insert_statements!(ir.code, ir.codelocs, + (stmt, i) -> i == 1 ? 3 : nothing, + (stmt, i) -> Any[Expr(:call, + Expr(:nooverdub, + GlobalRef(Base, :getfield)), + Expr(:contextslot), + QuoteNode(:metadata)), + Expr(:call, + Expr(:nooverdub, + GlobalRef(Base, :setindex!)), + SSAValue(1), true, + QuoteNode(:has_branching)), + stmt]) + Cassette.insert_statements!(ir.code, ir.codelocs, + (stmt, i) -> i > 2 && isa(stmt, Expr) ? 1 : nothing, + (stmt, i) -> begin + callstmt = Meta.isexpr(stmt, :(=)) ? stmt.args[2] : + stmt + Meta.isexpr(stmt, :call) || + Meta.isexpr(stmt, :invoke) || return Any[stmt] + callstmt = Expr(callstmt.head, + Expr(:nooverdub, callstmt.args[1]), + callstmt.args[2:end]...) + return Any[Meta.isexpr(stmt, :(=)) ? + Expr(:(=), stmt.args[1], callstmt) : + callstmt] + end) end return ir end @@ -55,8 +64,15 @@ end Cassette.overdub(::HasBranchingCtx, ::typeof(+), x...) = +(x...) Cassette.overdub(::HasBranchingCtx, ::typeof(*), x...) = *(x...) -Cassette.overdub(::HasBranchingCtx, ::typeof(Base.materialize), x...) = Base.materialize(x...) -Cassette.overdub(::HasBranchingCtx, ::typeof(Base.literal_pow), x...) = Base.literal_pow(x...) +function Cassette.overdub(::HasBranchingCtx, ::typeof(Base.materialize), x...) + Base.materialize(x...) +end +function Cassette.overdub(::HasBranchingCtx, ::typeof(Base.literal_pow), x...) + Base.literal_pow(x...) +end Cassette.overdub(::HasBranchingCtx, ::typeof(Base.getindex), x...) = Base.getindex(x...) Cassette.overdub(::HasBranchingCtx, ::typeof(Core.Typeof), x...) = Core.Typeof(x...) -Cassette.overdub(::HasBranchingCtx, ::Type{Base.OneTo{T}}, stop) where {T <: Integer} = Base.OneTo{T}(stop) +function Cassette.overdub(::HasBranchingCtx, ::Type{Base.OneTo{T}}, + stop) where {T <: Integer} + Base.OneTo{T}(stop) +end diff --git a/src/interpolating_adjoint.jl b/src/interpolating_adjoint.jl index e0a630a78..ef315fbd2 100644 --- a/src/interpolating_adjoint.jl +++ b/src/interpolating_adjoint.jl @@ -1,532 +1,587 @@ -struct ODEInterpolatingAdjointSensitivityFunction{C<:AdjointDiffCache,Alg<:InterpolatingAdjoint, - uType,SType,CPS,pType,fType<:DiffEqBase.AbstractDiffEqFunction} <: SensitivityFunction - diffcache::C - sensealg::Alg - discrete::Bool - y::uType - sol::SType - checkpoint_sol::CPS - prob::pType - f::fType - noiseterm::Bool +struct ODEInterpolatingAdjointSensitivityFunction{C <: AdjointDiffCache, + Alg <: InterpolatingAdjoint, + uType, SType, CPS, pType, + fType <: + DiffEqBase.AbstractDiffEqFunction} <: + SensitivityFunction + diffcache::C + sensealg::Alg + discrete::Bool + y::uType + sol::SType + checkpoint_sol::CPS + prob::pType + f::fType + noiseterm::Bool end -mutable struct CheckpointSolution{S,I,T,T2} - cpsol::S # solution in a checkpoint interval - intervals::I # checkpoint intervals - cursor::Int # sol.prob.tspan = intervals[cursor] - tols::T - tstops::T2 # for callbacks +mutable struct CheckpointSolution{S, I, T, T2} + cpsol::S # solution in a checkpoint interval + intervals::I # checkpoint intervals + cursor::Int # sol.prob.tspan = intervals[cursor] + tols::T + tstops::T2 # for callbacks end -function ODEInterpolatingAdjointSensitivityFunction(g,sensealg,discrete,sol,dg,f,checkpoints,tols,tstops=nothing;noiseterm=false) - tspan = reverse(sol.prob.tspan) - checkpointing = ischeckpointing(sensealg, sol) - (checkpointing && checkpoints === nothing) && error("checkpoints must be passed when checkpointing is enabled.") - - checkpoint_sol = if checkpointing - intervals = map(tuple, @view(checkpoints[1:end-1]), @view(checkpoints[2:end])) - interval_end = intervals[end][end] - tspan[1] > interval_end && push!(intervals, (interval_end, tspan[1])) - cursor = lastindex(intervals) - interval = intervals[cursor] - - if typeof(sol.prob) <: Union{SDEProblem,RODEProblem} - # replicated noise - _sol = deepcopy(sol) - idx1 = searchsortedfirst(_sol.W.t, interval[1]-1000eps(interval[1])) - if typeof(sol.W) <: DiffEqNoiseProcess.NoiseProcess - sol.W.save_everystep = false - _sol.W.save_everystep = false - forwardnoise = DiffEqNoiseProcess.NoiseWrapper(_sol.W, indx=idx1) - elseif typeof(sol.W) <: DiffEqNoiseProcess.NoiseGrid - #idx2 = searchsortedfirst(_sol.W.t, interval[2]+1000eps(interval[1])) - forwardnoise = DiffEqNoiseProcess.NoiseGrid(_sol.W.t[idx1:end], _sol.W.W[idx1:end]) - else - error("NoiseProcess type not implemented.") - end - dt = choose_dt((_sol.W.t[idx1]-_sol.W.t[idx1+1]), _sol.W.t, interval) - - cpsol = solve(remake(sol.prob, tspan=interval, u0=sol(interval[1]), noise=forwardnoise), - sol.alg, save_noise=false; dt=dt, tstops=_sol.t[idx1:end] ,tols...) - else - if tstops === nothing - cpsol = solve(remake(sol.prob, tspan=interval, u0=sol(interval[1])),sol.alg; tols...) - else - if maximum(interval[1] .< tstops .< interval[2]) - # callback might have changed p - _p = reset_p(sol.prob.kwargs[:callback], interval) - cpsol = solve(remake(sol.prob, tspan=interval, u0=sol(interval[1])),tstops=tstops, - p=_p, sol.alg; tols...) +function ODEInterpolatingAdjointSensitivityFunction(g, sensealg, discrete, sol, dg, f, + checkpoints, tols, tstops = nothing; + noiseterm = false) + tspan = reverse(sol.prob.tspan) + checkpointing = ischeckpointing(sensealg, sol) + (checkpointing && checkpoints === nothing) && + error("checkpoints must be passed when checkpointing is enabled.") + + checkpoint_sol = if checkpointing + intervals = map(tuple, @view(checkpoints[1:(end - 1)]), @view(checkpoints[2:end])) + interval_end = intervals[end][end] + tspan[1] > interval_end && push!(intervals, (interval_end, tspan[1])) + cursor = lastindex(intervals) + interval = intervals[cursor] + + if typeof(sol.prob) <: Union{SDEProblem, RODEProblem} + # replicated noise + _sol = deepcopy(sol) + idx1 = searchsortedfirst(_sol.W.t, interval[1] - 1000eps(interval[1])) + if typeof(sol.W) <: DiffEqNoiseProcess.NoiseProcess + sol.W.save_everystep = false + _sol.W.save_everystep = false + forwardnoise = DiffEqNoiseProcess.NoiseWrapper(_sol.W, indx = idx1) + elseif typeof(sol.W) <: DiffEqNoiseProcess.NoiseGrid + #idx2 = searchsortedfirst(_sol.W.t, interval[2]+1000eps(interval[1])) + forwardnoise = DiffEqNoiseProcess.NoiseGrid(_sol.W.t[idx1:end], + _sol.W.W[idx1:end]) + else + error("NoiseProcess type not implemented.") + end + dt = choose_dt((_sol.W.t[idx1] - _sol.W.t[idx1 + 1]), _sol.W.t, interval) + + cpsol = solve(remake(sol.prob, tspan = interval, u0 = sol(interval[1]), + noise = forwardnoise), + sol.alg, save_noise = false; dt = dt, tstops = _sol.t[idx1:end], + tols...) else - cpsol = solve(remake(sol.prob, tspan=interval, u0=sol(interval[1])),tstops=tstops, sol.alg; tols...) + if tstops === nothing + cpsol = solve(remake(sol.prob, tspan = interval, u0 = sol(interval[1])), + sol.alg; tols...) + else + if maximum(interval[1] .< tstops .< interval[2]) + # callback might have changed p + _p = reset_p(sol.prob.kwargs[:callback], interval) + cpsol = solve(remake(sol.prob, tspan = interval, u0 = sol(interval[1])), + tstops = tstops, + p = _p, sol.alg; tols...) + else + cpsol = solve(remake(sol.prob, tspan = interval, u0 = sol(interval[1])), + tstops = tstops, sol.alg; tols...) + end + end end - end - end - CheckpointSolution(cpsol, intervals, cursor, tols, tstops) + CheckpointSolution(cpsol, intervals, cursor, tols, tstops) - else - nothing - end + else + nothing + end - diffcache, y = adjointdiffcache(g,sensealg,discrete,sol,dg,f;quad=false,noiseterm=noiseterm) + diffcache, y = adjointdiffcache(g, sensealg, discrete, sol, dg, f; quad = false, + noiseterm = noiseterm) - return ODEInterpolatingAdjointSensitivityFunction(diffcache,sensealg, - discrete,y,sol, - checkpoint_sol,sol.prob,f,noiseterm) + return ODEInterpolatingAdjointSensitivityFunction(diffcache, sensealg, + discrete, y, sol, + checkpoint_sol, sol.prob, f, + noiseterm) end function findcursor(intervals, t) - # equivalent with `findfirst(x->x[1] <= t <= x[2], intervals)` - lt(x, t) = <(x[2], t) - return searchsortedfirst(intervals, t, lt=lt) + # equivalent with `findfirst(x->x[1] <= t <= x[2], intervals)` + lt(x, t) = <(x[2], t) + return searchsortedfirst(intervals, t, lt = lt) end function choose_dt(dt, ts, interval) - if dt < 1000eps(interval[2]) - if length(ts) > 2 - dt = ts[end-1]-ts[end-2] - if dt < 1000eps(interval[2]) - dt = interval[2] - interval[1] - end - else - dt = interval[2] - interval[1] + if dt < 1000eps(interval[2]) + if length(ts) > 2 + dt = ts[end - 1] - ts[end - 2] + if dt < 1000eps(interval[2]) + dt = interval[2] - interval[1] + end + else + dt = interval[2] - interval[1] + end end - end - return dt + return dt end # u = λ' # add tstop on all the checkpoints -function (S::ODEInterpolatingAdjointSensitivityFunction)(du,u,p,t) - @unpack sol,checkpoint_sol, discrete, prob, f = S - - λ,grad,y,dλ,dgrad,dy = split_states(du,u,t,S) - - if S.noiseterm - if length(u) == length(du) - vecjacobian!(dλ, y, λ, p, t, S, dgrad=dgrad) - elseif length(u) != length(du) && StochasticDiffEq.is_diagonal_noise(prob) && !isnoisemixing(S.sensealg) - vecjacobian!(dλ, y, λ, p, t, S) - jacNoise!(λ, y, p, t, S, dgrad=dgrad) +function (S::ODEInterpolatingAdjointSensitivityFunction)(du, u, p, t) + @unpack sol, checkpoint_sol, discrete, prob, f = S + + λ, grad, y, dλ, dgrad, dy = split_states(du, u, t, S) + + if S.noiseterm + if length(u) == length(du) + vecjacobian!(dλ, y, λ, p, t, S, dgrad = dgrad) + elseif length(u) != length(du) && StochasticDiffEq.is_diagonal_noise(prob) && + !isnoisemixing(S.sensealg) + vecjacobian!(dλ, y, λ, p, t, S) + jacNoise!(λ, y, p, t, S, dgrad = dgrad) + else + jacNoise!(λ, y, p, t, S, dgrad = dgrad, dλ = dλ) + end else - jacNoise!(λ, y, p, t, S, dgrad=dgrad, dλ=dλ) + vecjacobian!(dλ, y, λ, p, t, S, dgrad = dgrad) end - else - vecjacobian!(dλ, y, λ, p, t, S, dgrad=dgrad) - end - dλ .*= -one(eltype(λ)) - dgrad .*= -one(eltype(dgrad)) + dλ .*= -one(eltype(λ)) + dgrad .*= -one(eltype(dgrad)) - discrete || accumulate_cost!(dλ, y, p, t, S, dgrad) - return nothing + discrete || accumulate_cost!(dλ, y, p, t, S, dgrad) + return nothing end -function (S::ODEInterpolatingAdjointSensitivityFunction)(du,u,p,t,W) - @unpack sol,checkpoint_sol, discrete, prob, f = S +function (S::ODEInterpolatingAdjointSensitivityFunction)(du, u, p, t, W) + @unpack sol, checkpoint_sol, discrete, prob, f = S - λ,grad,y,dλ,dgrad,dy = split_states(du,u,t,S) + λ, grad, y, dλ, dgrad, dy = split_states(du, u, t, S) - vecjacobian!(dλ, y, λ, p, t, S, dgrad=dgrad, W=W) + vecjacobian!(dλ, y, λ, p, t, S, dgrad = dgrad, W = W) - dλ .*= -one(eltype(λ)) - dgrad .*= -one(eltype(dgrad)) + dλ .*= -one(eltype(λ)) + dgrad .*= -one(eltype(dgrad)) - discrete || accumulate_cost!(dλ, y, p, t, S, dgrad) - return nothing + discrete || accumulate_cost!(dλ, y, p, t, S, dgrad) + return nothing end -function split_states(du,u,t,S::TS;update=true) where TS<:ODEInterpolatingAdjointSensitivityFunction - @unpack sol, y, checkpoint_sol, discrete, prob, f = S - idx = length(y) - - if update - if checkpoint_sol === nothing - if typeof(t) <: ForwardDiff.Dual && eltype(S.y) <: AbstractFloat - y = sol(t, continuity=:right) - else - sol(y,t, continuity=:right) - end - else - intervals = checkpoint_sol.intervals - interval = intervals[checkpoint_sol.cursor] - if !(interval[1] <= t <= interval[2]) - cursor′ = findcursor(intervals, t) - interval = intervals[cursor′] - cpsol_t = checkpoint_sol.cpsol.t - if typeof(t) <: ForwardDiff.Dual && eltype(S.y) <: AbstractFloat - y = sol(interval[1]) - else - sol(y, interval[1]) - end - if typeof(sol.prob) <: Union{SDEProblem,RODEProblem} - #idx1 = searchsortedfirst(sol.t, interval[1]) - _sol = deepcopy(sol) - idx1 = searchsortedfirst(_sol.t, interval[1]-100eps(interval[1])) - idx2 = searchsortedfirst(_sol.t, interval[2]+100eps(interval[2])) - idx_noise = searchsortedfirst(_sol.W.t, interval[1]-100eps(interval[1])) - if typeof(sol.W) <: DiffEqNoiseProcess.NoiseProcess - _sol.W.save_everystep = false - forwardnoise = DiffEqNoiseProcess.NoiseWrapper(_sol.W, indx=idx_noise) - elseif typeof(sol.W) <: DiffEqNoiseProcess.NoiseGrid - forwardnoise = DiffEqNoiseProcess.NoiseGrid(_sol.W.t[idx_noise:end], _sol.W.W[idx_noise:end]) - else - error("NoiseProcess type not implemented.") - end - prob′ = remake(prob, tspan=intervals[cursor′], u0=y, noise=forwardnoise) - dt = choose_dt(abs(cpsol_t[1]-cpsol_t[2]), cpsol_t, interval) - cpsol′ = solve(prob′, sol.alg, save_noise=false; dt=dt, tstops=_sol.t[idx1:idx2], checkpoint_sol.tols...) - else - if checkpoint_sol.tstops===nothing - prob′ = remake(prob, tspan=intervals[cursor′], u0=y) - cpsol′ = solve(prob′, sol.alg; dt=abs(cpsol_t[end] - cpsol_t[end-1]), checkpoint_sol.tols...) - else - if maximum(interval[1] .< checkpoint_sol.tstops .< interval[2]) - # callback might have changed p - _p = reset_p(prob.kwargs[:callback], interval) - prob′ = remake(prob, tspan=intervals[cursor′], u0=y, p=_p) - cpsol′ = solve(prob′, sol.alg; dt=abs(cpsol_t[end] - cpsol_t[end-1]), tstops=checkpoint_sol.tstops, checkpoint_sol.tols...) +function split_states(du, u, t, S::TS; + update = true) where {TS <: + ODEInterpolatingAdjointSensitivityFunction} + @unpack sol, y, checkpoint_sol, discrete, prob, f = S + idx = length(y) + + if update + if checkpoint_sol === nothing + if typeof(t) <: ForwardDiff.Dual && eltype(S.y) <: AbstractFloat + y = sol(t, continuity = :right) else - prob′ = remake(prob, tspan=intervals[cursor′], u0=y) - cpsol′ = solve(prob′, sol.alg; dt=abs(cpsol_t[end] - cpsol_t[end-1]), tstops=checkpoint_sol.tstops, checkpoint_sol.tols...) + sol(y, t, continuity = :right) + end + else + intervals = checkpoint_sol.intervals + interval = intervals[checkpoint_sol.cursor] + if !(interval[1] <= t <= interval[2]) + cursor′ = findcursor(intervals, t) + interval = intervals[cursor′] + cpsol_t = checkpoint_sol.cpsol.t + if typeof(t) <: ForwardDiff.Dual && eltype(S.y) <: AbstractFloat + y = sol(interval[1]) + else + sol(y, interval[1]) + end + if typeof(sol.prob) <: Union{SDEProblem, RODEProblem} + #idx1 = searchsortedfirst(sol.t, interval[1]) + _sol = deepcopy(sol) + idx1 = searchsortedfirst(_sol.t, interval[1] - 100eps(interval[1])) + idx2 = searchsortedfirst(_sol.t, interval[2] + 100eps(interval[2])) + idx_noise = searchsortedfirst(_sol.W.t, + interval[1] - 100eps(interval[1])) + if typeof(sol.W) <: DiffEqNoiseProcess.NoiseProcess + _sol.W.save_everystep = false + forwardnoise = DiffEqNoiseProcess.NoiseWrapper(_sol.W, + indx = idx_noise) + elseif typeof(sol.W) <: DiffEqNoiseProcess.NoiseGrid + forwardnoise = DiffEqNoiseProcess.NoiseGrid(_sol.W.t[idx_noise:end], + _sol.W.W[idx_noise:end]) + else + error("NoiseProcess type not implemented.") + end + prob′ = remake(prob, tspan = intervals[cursor′], u0 = y, + noise = forwardnoise) + dt = choose_dt(abs(cpsol_t[1] - cpsol_t[2]), cpsol_t, interval) + cpsol′ = solve(prob′, sol.alg, save_noise = false; dt = dt, + tstops = _sol.t[idx1:idx2], checkpoint_sol.tols...) + else + if checkpoint_sol.tstops === nothing + prob′ = remake(prob, tspan = intervals[cursor′], u0 = y) + cpsol′ = solve(prob′, sol.alg; + dt = abs(cpsol_t[end] - cpsol_t[end - 1]), + checkpoint_sol.tols...) + else + if maximum(interval[1] .< checkpoint_sol.tstops .< interval[2]) + # callback might have changed p + _p = reset_p(prob.kwargs[:callback], interval) + prob′ = remake(prob, tspan = intervals[cursor′], u0 = y, p = _p) + cpsol′ = solve(prob′, sol.alg; + dt = abs(cpsol_t[end] - cpsol_t[end - 1]), + tstops = checkpoint_sol.tstops, + checkpoint_sol.tols...) + else + prob′ = remake(prob, tspan = intervals[cursor′], u0 = y) + cpsol′ = solve(prob′, sol.alg; + dt = abs(cpsol_t[end] - cpsol_t[end - 1]), + tstops = checkpoint_sol.tstops, + checkpoint_sol.tols...) + end + end + end + checkpoint_sol.cpsol = cpsol′ + checkpoint_sol.cursor = cursor′ end - end + checkpoint_sol.cpsol(y, t, continuity = :right) end - checkpoint_sol.cpsol = cpsol′ - checkpoint_sol.cursor = cursor′ - end - checkpoint_sol.cpsol(y, t, continuity=:right) end - end - λ = @view u[1:idx] - grad = @view u[idx+1:end] + λ = @view u[1:idx] + grad = @view u[(idx + 1):end] - if length(u) == length(du) - dλ = @view du[1:idx] - dgrad = @view du[idx+1:end] + if length(u) == length(du) + dλ = @view du[1:idx] + dgrad = @view du[(idx + 1):end] - elseif length(u) != length(du) && StochasticDiffEq.is_diagonal_noise(prob) && !isnoisemixing(S.sensealg) - idx1 = [length(u)*(i-1)+i for i in 1:idx] # for diagonal indices of [1:idx,1:idx] + elseif length(u) != length(du) && StochasticDiffEq.is_diagonal_noise(prob) && + !isnoisemixing(S.sensealg) + idx1 = [length(u) * (i - 1) + i for i in 1:idx] # for diagonal indices of [1:idx,1:idx] - dλ = @view du[idx1] - dgrad = @view du[idx+1:end,1:idx] + dλ = @view du[idx1] + dgrad = @view du[(idx + 1):end, 1:idx] - elseif typeof(du) <: AbstractMatrix - # non-diagonal noise and noise mixing case - dλ = @view du[1:idx,1:idx] - dgrad = @view du[idx+1:end,1:idx] - end + elseif typeof(du) <: AbstractMatrix + # non-diagonal noise and noise mixing case + dλ = @view du[1:idx, 1:idx] + dgrad = @view du[(idx + 1):end, 1:idx] + end - λ,grad,y,dλ,dgrad,nothing + λ, grad, y, dλ, dgrad, nothing end # g is either g(t,u,p) or discrete g(t,u,i) -@noinline function ODEAdjointProblem(sol,sensealg::InterpolatingAdjoint, - t=nothing, - dg_discrete::DG1=nothing,dg_continuous::DG2=nothing, - g::G=nothing; - checkpoints=sol.t, - callback=CallbackSet(), - reltol=nothing, abstol=nothing, - kwargs...) where {DG1,DG2,G} - - dg_discrete===nothing && dg_continuous===nothing && g===nothing && error("Either `dg_discrete`, `dg_continuous`, or `g` must be specified.") - - @unpack f, p, u0, tspan = sol.prob - tspan = reverse(tspan) - discrete = (t !== nothing && dg_continuous === nothing) - - # remove duplicates from checkpoints - if ischeckpointing(sensealg, sol) && (length(unique(checkpoints)) != length(checkpoints)) - _checkpoints, duplicate_iterator_times = separate_nonunique(checkpoints) - tstops = duplicate_iterator_times[1] - checkpoints = filter(x -> x ∉ tstops, _checkpoints) - # check if start is in checkpoints. Otherwise first interval is missed. - if checkpoints[1] != tspan[2] - pushfirst!(checkpoints, tspan[2]) +@noinline function ODEAdjointProblem(sol, sensealg::InterpolatingAdjoint, + t = nothing, + dg_discrete::DG1 = nothing, + dg_continuous::DG2 = nothing, + g::G = nothing; + checkpoints = sol.t, + callback = CallbackSet(), + reltol = nothing, abstol = nothing, + kwargs...) where {DG1, DG2, G} + dg_discrete === nothing && dg_continuous === nothing && g === nothing && + error("Either `dg_discrete`, `dg_continuous`, or `g` must be specified.") + + @unpack f, p, u0, tspan = sol.prob + tspan = reverse(tspan) + discrete = (t !== nothing && dg_continuous === nothing) + + # remove duplicates from checkpoints + if ischeckpointing(sensealg, sol) && + (length(unique(checkpoints)) != length(checkpoints)) + _checkpoints, duplicate_iterator_times = separate_nonunique(checkpoints) + tstops = duplicate_iterator_times[1] + checkpoints = filter(x -> x ∉ tstops, _checkpoints) + # check if start is in checkpoints. Otherwise first interval is missed. + if checkpoints[1] != tspan[2] + pushfirst!(checkpoints, tspan[2]) + end + + if haskey(kwargs, :tstops) + (tstops !== kwargs[:tstops]) && unique!(push!(tstops, kwargs[:tstops]...)) + end + + else + tstops = nothing + end + + numstates = length(u0) + numparams = p === nothing || p === DiffEqBase.NullParameters() ? 0 : length(p) + + len = numstates + numparams + + λ = p === nothing || p === DiffEqBase.NullParameters() ? similar(u0) : + one(eltype(u0)) .* similar(p, len) + λ .= false + + sense = ODEInterpolatingAdjointSensitivityFunction(g, sensealg, discrete, sol, + dg_continuous, f, + checkpoints, + (reltol = reltol, abstol = abstol), + tstops) + + init_cb = (discrete || dg_discrete !== nothing) + cb, duplicate_iterator_times = generate_callbacks(sense, dg_discrete, λ, t, tspan[2], + callback, init_cb) + z0 = vec(zero(λ)) + original_mm = sol.prob.f.mass_matrix + if original_mm === I || original_mm === (I, I) + mm = I + else + adjmm = copy(sol.prob.f.mass_matrix') + zzz = similar(adjmm, numstates, numparams) + fill!(zzz, zero(eltype(zzz))) + # using concrate I is slightly more efficient + II = Diagonal(I, numparams) + mm = [adjmm zzz + copy(zzz') II] end - if haskey(kwargs, :tstops) - (tstops !== kwargs[:tstops]) && unique!(push!(tstops, kwargs[:tstops]...)) + jac_prototype = sol.prob.f.jac_prototype + if !sense.discrete || jac_prototype === nothing + adjoint_jac_prototype = nothing + else + _adjoint_jac_prototype = copy(jac_prototype') + zzz = similar(_adjoint_jac_prototype, numstates, numparams) + fill!(zzz, zero(eltype(zzz))) + II = Diagonal(I, numparams) + adjoint_jac_prototype = [_adjoint_jac_prototype zzz + copy(zzz') II] end - else - tstops = nothing - end - - numstates = length(u0) - numparams = p === nothing || p === DiffEqBase.NullParameters() ? 0 : length(p) - - len = numstates+numparams - - λ = p === nothing || p === DiffEqBase.NullParameters() ? similar(u0) : one(eltype(u0)) .* similar(p, len) - λ .= false - - sense = ODEInterpolatingAdjointSensitivityFunction(g,sensealg,discrete,sol,dg_continuous,f, - checkpoints, - (reltol=reltol,abstol=abstol), - tstops) - - init_cb = (discrete || dg_discrete!==nothing) - cb, duplicate_iterator_times = generate_callbacks(sense, dg_discrete, λ, t, tspan[2], callback, init_cb) - z0 = vec(zero(λ)) - original_mm = sol.prob.f.mass_matrix - if original_mm === I || original_mm === (I,I) - mm = I - else - adjmm = copy(sol.prob.f.mass_matrix') - zzz = similar(adjmm, numstates, numparams) - fill!(zzz, zero(eltype(zzz))) - # using concrate I is slightly more efficient - II = Diagonal(I, numparams) - mm = [adjmm zzz - copy(zzz') II] - end - - jac_prototype = sol.prob.f.jac_prototype - if !sense.discrete || jac_prototype === nothing - adjoint_jac_prototype = nothing - else - _adjoint_jac_prototype = copy(jac_prototype') - zzz = similar(_adjoint_jac_prototype, numstates, numparams) - fill!(zzz, zero(eltype(zzz))) - II = Diagonal(I, numparams) - adjoint_jac_prototype = [_adjoint_jac_prototype zzz - copy(zzz') II] - end - - odefun = ODEFunction(sense, mass_matrix=mm, jac_prototype=adjoint_jac_prototype) - return ODEProblem(odefun,z0,tspan,p,callback=cb) + odefun = ODEFunction(sense, mass_matrix = mm, jac_prototype = adjoint_jac_prototype) + return ODEProblem(odefun, z0, tspan, p, callback = cb) end +@noinline function SDEAdjointProblem(sol, sensealg::InterpolatingAdjoint, + t = nothing, + dg_discrete::DG1 = nothing, + dg_continuous::DG2 = nothing, + g::G = nothing; + checkpoints = sol.t, + callback = CallbackSet(), + reltol = nothing, abstol = nothing, + diffusion_jac = nothing, diffusion_paramjac = nothing, + kwargs...) where {DG1, DG2, G} + dg_discrete === nothing && dg_continuous === nothing && g === nothing && + error("Either `dg_discrete`, `dg_continuous`, or `g` must be specified.") + + @unpack f, p, u0, tspan = sol.prob + tspan = reverse(tspan) + discrete = (t !== nothing && dg_continuous === nothing) + + # remove duplicates from checkpoints + if ischeckpointing(sensealg, sol) && + (length(unique(checkpoints)) != length(checkpoints)) + _checkpoints, duplicate_iterator_times = separate_nonunique(checkpoints) + tstops = duplicate_iterator_times[1] + checkpoints = filter(x -> x ∉ tstops, _checkpoints) + # check if start is in checkpoints. Otherwise first interval is missed. + if checkpoints[1] != tspan[2] + pushfirst!(checkpoints, tspan[2]) + end + else + tstops = nothing + end -@noinline function SDEAdjointProblem(sol,sensealg::InterpolatingAdjoint, - t=nothing, - dg_discrete::DG1=nothing,dg_continuous::DG2=nothing, - g::G=nothing; - checkpoints=sol.t, - callback=CallbackSet(), - reltol=nothing, abstol=nothing, - diffusion_jac=nothing, diffusion_paramjac=nothing, - kwargs...) where {DG1,DG2,G} - - dg_discrete===nothing && dg_continuous===nothing && g===nothing && error("Either `dg_discrete`, `dg_continuous`, or `g` must be specified.") - - @unpack f, p, u0, tspan = sol.prob - tspan = reverse(tspan) - discrete = (t !== nothing && dg_continuous === nothing) - - # remove duplicates from checkpoints - if ischeckpointing(sensealg,sol) && (length(unique(checkpoints)) != length(checkpoints)) - _checkpoints, duplicate_iterator_times = separate_nonunique(checkpoints) - tstops = duplicate_iterator_times[1] - checkpoints = filter(x->x ∉ tstops, _checkpoints) - # check if start is in checkpoints. Otherwise first interval is missed. - if checkpoints[1] != tspan[2] - pushfirst!(checkpoints,tspan[2]) + numstates = length(u0) + numparams = p === nothing || p === DiffEqBase.NullParameters() ? 0 : length(p) + + len = numstates + numparams + + λ = one(eltype(u0)) .* similar(p, len) + λ .= false + + sense_drift = ODEInterpolatingAdjointSensitivityFunction(g, sensealg, discrete, sol, + dg_continuous, sol.prob.f, + checkpoints, + (reltol = reltol, + abstol = abstol)) + + diffusion_function = ODEFunction(sol.prob.g, jac = diffusion_jac, + paramjac = diffusion_paramjac) + sense_diffusion = ODEInterpolatingAdjointSensitivityFunction(g, sensealg, discrete, sol, + dg_continuous, + diffusion_function, + checkpoints, + (reltol = reltol, + abstol = abstol); + noiseterm = true) + + init_cb = (discrete || dg_discrete !== nothing) # && tspan[1] == t[end] + cb, duplicate_iterator_times = generate_callbacks(sense_drift, dg_discrete, λ, t, + tspan[2], callback, init_cb) + z0 = vec(zero(λ)) + original_mm = sol.prob.f.mass_matrix + if original_mm === I || original_mm === (I, I) + mm = I + else + adjmm = copy(sol.prob.f.mass_matrix') + zzz = similar(adjmm, numstates, numparams) + fill!(zzz, zero(eltype(zzz))) + # using concrate I is slightly more efficient + II = Diagonal(I, numparams) + mm = [adjmm zzz + copy(zzz') II] end - else - tstops = nothing - end - - numstates = length(u0) - numparams = p === nothing || p === DiffEqBase.NullParameters() ? 0 : length(p) - - len = numstates+numparams - - λ = one(eltype(u0)) .* similar(p, len) - λ .= false - - sense_drift = ODEInterpolatingAdjointSensitivityFunction(g,sensealg,discrete,sol,dg_continuous,sol.prob.f, - checkpoints,(reltol=reltol,abstol=abstol)) - - diffusion_function = ODEFunction(sol.prob.g, jac=diffusion_jac, paramjac=diffusion_paramjac) - sense_diffusion = ODEInterpolatingAdjointSensitivityFunction(g,sensealg,discrete,sol,dg_continuous,diffusion_function, - checkpoints,(reltol=reltol,abstol=abstol);noiseterm=true) - - init_cb = (discrete || dg_discrete!==nothing) # && tspan[1] == t[end] - cb, duplicate_iterator_times = generate_callbacks(sense_drift, dg_discrete, λ, t, tspan[2], callback, init_cb) - z0 = vec(zero(λ)) - original_mm = sol.prob.f.mass_matrix - if original_mm === I || original_mm === (I,I) - mm = I - else - adjmm = copy(sol.prob.f.mass_matrix') - zzz = similar(adjmm, numstates, numparams) - fill!(zzz, zero(eltype(zzz))) - # using concrate I is slightly more efficient - II = Diagonal(I, numparams) - mm = [adjmm zzz - copy(zzz') II] - end - - jac_prototype = sol.prob.f.jac_prototype - if !sense_drift.discrete || jac_prototype === nothing - adjoint_jac_prototype = nothing - else - _adjoint_jac_prototype = copy(jac_prototype') - zzz = similar(_adjoint_jac_prototype, numstates, numparams) - fill!(zzz, zero(eltype(zzz))) - II = Diagonal(I, numparams) - adjoint_jac_prototype = [_adjoint_jac_prototype zzz - copy(zzz') II] - end - - sdefun = SDEFunction(sense_drift,sense_diffusion,mass_matrix=mm,jac_prototype=adjoint_jac_prototype) - - # replicated noise - _sol = deepcopy(sol) - backwardnoise = reverse(_sol.W) - - if StochasticDiffEq.is_diagonal_noise(sol.prob) && typeof(sol.W[end])<:Number - # scalar noise case - noise_matrix = nothing - else - noise_matrix = similar(z0,length(z0),numstates) - noise_matrix .= false - end - - return SDEProblem(sdefun,sense_diffusion,z0,tspan,p, - callback=cb, - noise=backwardnoise, - noise_rate_prototype = noise_matrix - ) -end + jac_prototype = sol.prob.f.jac_prototype + if !sense_drift.discrete || jac_prototype === nothing + adjoint_jac_prototype = nothing + else + _adjoint_jac_prototype = copy(jac_prototype') + zzz = similar(_adjoint_jac_prototype, numstates, numparams) + fill!(zzz, zero(eltype(zzz))) + II = Diagonal(I, numparams) + adjoint_jac_prototype = [_adjoint_jac_prototype zzz + copy(zzz') II] + end + + sdefun = SDEFunction(sense_drift, sense_diffusion, mass_matrix = mm, + jac_prototype = adjoint_jac_prototype) -@noinline function RODEAdjointProblem(sol,sensealg::InterpolatingAdjoint, - t=nothing, - dg_discrete::DG1=nothing,dg_continuous::DG2=nothing, - g::G=nothing; - checkpoints=sol.t, - callback=CallbackSet(), - reltol=nothing, abstol=nothing, - kwargs...) where {DG1,DG2,G} - @unpack f, p, u0, tspan = sol.prob - tspan = reverse(tspan) - discrete = (t !== nothing && dg_continuous === nothing) - - # remove duplicates from checkpoints - if ischeckpointing(sensealg,sol) && (length(unique(checkpoints)) != length(checkpoints)) - _checkpoints, duplicate_iterator_times = separate_nonunique(checkpoints) - tstops = duplicate_iterator_times[1] - checkpoints = filter(x->x ∉ tstops, _checkpoints) - # check if start is in checkpoints. Otherwise first interval is missed. - if checkpoints[1] != tspan[2] - pushfirst!(checkpoints,tspan[2]) + # replicated noise + _sol = deepcopy(sol) + backwardnoise = reverse(_sol.W) + + if StochasticDiffEq.is_diagonal_noise(sol.prob) && typeof(sol.W[end]) <: Number + # scalar noise case + noise_matrix = nothing + else + noise_matrix = similar(z0, length(z0), numstates) + noise_matrix .= false end - else - tstops = nothing - end - - numstates = length(u0) - numparams = p === nothing || p === DiffEqBase.NullParameters() ? 0 : length(p) - - len = numstates+numparams - - λ = p === nothing || p === DiffEqBase.NullParameters() ? similar(u0) : one(eltype(u0)) .* similar(p, len) - λ .= false - - sense = ODEInterpolatingAdjointSensitivityFunction(g,sensealg,discrete,sol,dg_continuous,f, - checkpoints, - (reltol=reltol,abstol=abstol), - tstops) - - init_cb = (discrete || dg_discrete!==nothing) # && tspan[1] == t[end] - cb, duplicate_iterator_times = generate_callbacks(sense, dg_discrete, λ, t, tspan[2], callback, init_cb) - z0 = vec(zero(λ)) - original_mm = sol.prob.f.mass_matrix - if original_mm === I || original_mm === (I,I) - mm = I - else - adjmm = copy(sol.prob.f.mass_matrix') - zzz = similar(adjmm, numstates, numparams) - fill!(zzz, zero(eltype(zzz))) - # using concrate I is slightly more efficient - II = Diagonal(I, numparams) - mm = [adjmm zzz - copy(zzz') II] - end - - jac_prototype = sol.prob.f.jac_prototype - if !sense.discrete || jac_prototype === nothing - adjoint_jac_prototype = nothing - else - _adjoint_jac_prototype = copy(jac_prototype') - zzz = similar(_adjoint_jac_prototype, numstates, numparams) - fill!(zzz, zero(eltype(zzz))) - II = Diagonal(I, numparams) - adjoint_jac_prototype = [_adjoint_jac_prototype zzz - copy(zzz') II] - end - - rodefun = RODEFunction(sense, mass_matrix=mm, jac_prototype=adjoint_jac_prototype) - - # replicated noise - _sol = deepcopy(sol) - backwardnoise = reverse(_sol.W) - # make sure noise grid starts at correct time values, e.g., if sol.W.t is longer than sol.t - tspan[1]!=backwardnoise.t[1] && reinit!(backwardnoise,backwardnoise.t[2]-backwardnoise.t[1],t0=tspan[1]) - - return RODEProblem(rodefun,z0,tspan,p,callback=cb, - noise=backwardnoise) + + return SDEProblem(sdefun, sense_diffusion, z0, tspan, p, + callback = cb, + noise = backwardnoise, + noise_rate_prototype = noise_matrix) end +@noinline function RODEAdjointProblem(sol, sensealg::InterpolatingAdjoint, + t = nothing, + dg_discrete::DG1 = nothing, + dg_continuous::DG2 = nothing, + g::G = nothing; + checkpoints = sol.t, + callback = CallbackSet(), + reltol = nothing, abstol = nothing, + kwargs...) where {DG1, DG2, G} + @unpack f, p, u0, tspan = sol.prob + tspan = reverse(tspan) + discrete = (t !== nothing && dg_continuous === nothing) + + # remove duplicates from checkpoints + if ischeckpointing(sensealg, sol) && + (length(unique(checkpoints)) != length(checkpoints)) + _checkpoints, duplicate_iterator_times = separate_nonunique(checkpoints) + tstops = duplicate_iterator_times[1] + checkpoints = filter(x -> x ∉ tstops, _checkpoints) + # check if start is in checkpoints. Otherwise first interval is missed. + if checkpoints[1] != tspan[2] + pushfirst!(checkpoints, tspan[2]) + end + else + tstops = nothing + end -function reset_p(CBS, interval) - # check which events are close to tspan[1] - if !isempty(CBS.discrete_callbacks) - ts = map(CBS.discrete_callbacks) do cb - indx = searchsortedfirst(cb.affect!.event_times, interval[1]) - (indx, cb.affect!.event_times[indx]) + numstates = length(u0) + numparams = p === nothing || p === DiffEqBase.NullParameters() ? 0 : length(p) + + len = numstates + numparams + + λ = p === nothing || p === DiffEqBase.NullParameters() ? similar(u0) : + one(eltype(u0)) .* similar(p, len) + λ .= false + + sense = ODEInterpolatingAdjointSensitivityFunction(g, sensealg, discrete, sol, + dg_continuous, f, + checkpoints, + (reltol = reltol, abstol = abstol), + tstops) + + init_cb = (discrete || dg_discrete !== nothing) # && tspan[1] == t[end] + cb, duplicate_iterator_times = generate_callbacks(sense, dg_discrete, λ, t, tspan[2], + callback, init_cb) + z0 = vec(zero(λ)) + original_mm = sol.prob.f.mass_matrix + if original_mm === I || original_mm === (I, I) + mm = I + else + adjmm = copy(sol.prob.f.mass_matrix') + zzz = similar(adjmm, numstates, numparams) + fill!(zzz, zero(eltype(zzz))) + # using concrate I is slightly more efficient + II = Diagonal(I, numparams) + mm = [adjmm zzz + copy(zzz') II] end - perm = minimum(sortperm([t for t in getindex.(ts,2)])) - end - - if !isempty(CBS.continuous_callbacks) - ts2 = map(CBS.continuous_callbacks) do cb - if !isempty(cb.affect!.event_times) && isempty(cb.affect_neg!.event_times) - indx = searchsortedfirst(cb.affect!.event_times, interval[1]) - return (indx, cb.affect!.event_times[indx],0) # zero for affect! - elseif isempty(cb.affect!.event_times) && !isempty(cb.affect_neg!.event_times) - indx = searchsortedfirst(cb.affect_neg!.event_times, interval[1]) - return (indx, cb.affect_neg!.event_times[indx],1) # one for affect_neg! - elseif !isempty(cb.affect!.event_times) && !isempty(cb.affect_neg!.event_times) - indx1 = searchsortedfirst(cb.affect!.event_times, interval[1]) - indx2 = searchsortedfirst(cb.affect_neg!.event_times, interval[1]) - if cb.affect!.event_times[indx1] < cb.affect_neg!.event_times[indx2] - return (indx1, cb.affect!.event_times[indx1],0) - else - return (indx2, cb.affect_neg!.event_times[indx2],1) + + jac_prototype = sol.prob.f.jac_prototype + if !sense.discrete || jac_prototype === nothing + adjoint_jac_prototype = nothing + else + _adjoint_jac_prototype = copy(jac_prototype') + zzz = similar(_adjoint_jac_prototype, numstates, numparams) + fill!(zzz, zero(eltype(zzz))) + II = Diagonal(I, numparams) + adjoint_jac_prototype = [_adjoint_jac_prototype zzz + copy(zzz') II] + end + + rodefun = RODEFunction(sense, mass_matrix = mm, jac_prototype = adjoint_jac_prototype) + + # replicated noise + _sol = deepcopy(sol) + backwardnoise = reverse(_sol.W) + # make sure noise grid starts at correct time values, e.g., if sol.W.t is longer than sol.t + tspan[1] != backwardnoise.t[1] && + reinit!(backwardnoise, backwardnoise.t[2] - backwardnoise.t[1], t0 = tspan[1]) + + return RODEProblem(rodefun, z0, tspan, p, callback = cb, + noise = backwardnoise) +end + +function reset_p(CBS, interval) + # check which events are close to tspan[1] + if !isempty(CBS.discrete_callbacks) + ts = map(CBS.discrete_callbacks) do cb + indx = searchsortedfirst(cb.affect!.event_times, interval[1]) + (indx, cb.affect!.event_times[indx]) end - else - error("Expected event but reset_p couldn't find event time. Please report this error.") - end + perm = minimum(sortperm([t for t in getindex.(ts, 2)])) end - perm2 = minimum(sortperm([t for t in getindex.(ts2,2)])) - # check if continuous or discrete callback was applied first if both occur in interval - if isempty(CBS.discrete_callbacks) - if ts2[perm2][3] == 0 - p = deepcopy(CBS.continuous_callbacks[perm2].affect!.pleft[getindex.(ts2,1)[perm2]]) - else - p = deepcopy(CBS.continuous_callbacks[perm2].affect_neg!.pleft[getindex.(ts2,1)[perm2]]) - end - else - if ts[perm][2] < ts2[perm2][2] - p = deepcopy(CBS.discrete_callbacks[perm].affect!.pleft[getindex.(ts,1)[perm]]) - else - if ts2[perm2][3] == 0 - p = deepcopy(CBS.continuous_callbacks[perm2].affect!.pleft[getindex.(ts2,1)[perm2]]) + + if !isempty(CBS.continuous_callbacks) + ts2 = map(CBS.continuous_callbacks) do cb + if !isempty(cb.affect!.event_times) && isempty(cb.affect_neg!.event_times) + indx = searchsortedfirst(cb.affect!.event_times, interval[1]) + return (indx, cb.affect!.event_times[indx], 0) # zero for affect! + elseif isempty(cb.affect!.event_times) && !isempty(cb.affect_neg!.event_times) + indx = searchsortedfirst(cb.affect_neg!.event_times, interval[1]) + return (indx, cb.affect_neg!.event_times[indx], 1) # one for affect_neg! + elseif !isempty(cb.affect!.event_times) && !isempty(cb.affect_neg!.event_times) + indx1 = searchsortedfirst(cb.affect!.event_times, interval[1]) + indx2 = searchsortedfirst(cb.affect_neg!.event_times, interval[1]) + if cb.affect!.event_times[indx1] < cb.affect_neg!.event_times[indx2] + return (indx1, cb.affect!.event_times[indx1], 0) + else + return (indx2, cb.affect_neg!.event_times[indx2], 1) + end + else + error("Expected event but reset_p couldn't find event time. Please report this error.") + end + end + perm2 = minimum(sortperm([t for t in getindex.(ts2, 2)])) + # check if continuous or discrete callback was applied first if both occur in interval + if isempty(CBS.discrete_callbacks) + if ts2[perm2][3] == 0 + p = deepcopy(CBS.continuous_callbacks[perm2].affect!.pleft[getindex.(ts2, 1)[perm2]]) + else + p = deepcopy(CBS.continuous_callbacks[perm2].affect_neg!.pleft[getindex.(ts2, + 1)[perm2]]) + end else - p = deepcopy(CBS.continuous_callbacks[perm2].affect_neg!.pleft[getindex.(ts2,1)[perm2]]) + if ts[perm][2] < ts2[perm2][2] + p = deepcopy(CBS.discrete_callbacks[perm].affect!.pleft[getindex.(ts, 1)[perm]]) + else + if ts2[perm2][3] == 0 + p = deepcopy(CBS.continuous_callbacks[perm2].affect!.pleft[getindex.(ts2, + 1)[perm2]]) + else + p = deepcopy(CBS.continuous_callbacks[perm2].affect_neg!.pleft[getindex.(ts2, + 1)[perm2]]) + end + end end - end + else + p = deepcopy(CBS.discrete_callbacks[perm].affect!.pleft[getindex.(ts, 1)[perm]]) end - else - p = deepcopy(CBS.discrete_callbacks[perm].affect!.pleft[getindex.(ts,1)[perm]]) - end - return p + return p end diff --git a/src/lss.jl b/src/lss.jl index 84843a9e7..635eb2414 100644 --- a/src/lss.jl +++ b/src/lss.jl @@ -1,643 +1,662 @@ -struct LSSSchur{wBType,wEType,BType,EType} - wBinv::wBType - wEinv::wEType - B::BType - E::EType +struct LSSSchur{wBType, wEType, BType, EType} + wBinv::wBType + wEinv::wEType + B::BType + E::EType end -struct LSSSensitivityFunction{iip,F,A,J,JP,S,PJ,UF,PF,JC,PJC,Alg,fc,JM,pJM,MM,CV, - PGPU,PGPP,CONFU,CONGP,DG} <: DiffEqBase.AbstractODEFunction{iip} - f::F - analytic::A - jac::J - jac_prototype::JP - sparsity::S - paramjac::PJ - uf::UF - pf::PF - J::JM - pJ::pJM - jac_config::JC - paramjac_config::PJC - alg::Alg - numparams::Int - numindvar::Int - f_cache::fc - mass_matrix::MM - colorvec::CV - pgpu::PGPU - pgpp::PGPP - pgpu_config::CONFU - pgpp_config::CONGP - dg_val::DG +struct LSSSensitivityFunction{iip, F, A, J, JP, S, PJ, UF, PF, JC, PJC, Alg, fc, JM, pJM, + MM, CV, + PGPU, PGPP, CONFU, CONGP, DG} <: + DiffEqBase.AbstractODEFunction{iip} + f::F + analytic::A + jac::J + jac_prototype::JP + sparsity::S + paramjac::PJ + uf::UF + pf::PF + J::JM + pJ::pJM + jac_config::JC + paramjac_config::PJC + alg::Alg + numparams::Int + numindvar::Int + f_cache::fc + mass_matrix::MM + colorvec::CV + pgpu::PGPU + pgpp::PGPP + pgpu_config::CONFU + pgpp_config::CONGP + dg_val::DG end -function LSSSensitivityFunction(sensealg,f,analytic,jac,jac_prototype,sparsity,paramjac,u0, - alg,p,f_cache,mm, - colorvec,tspan,g,dg) - - uf = DiffEqBase.UJacobianWrapper(f,tspan[1],p) - pf = DiffEqBase.ParamJacobianWrapper(f,tspan[1],copy(u0)) - - if DiffEqBase.has_jac(f) - jac_config = nothing - else - jac_config = build_jac_config(sensealg,uf,u0) - end - - if DiffEqBase.has_paramjac(f) - paramjac_config = nothing - else - paramjac_config = build_param_jac_config(sensealg,pf,u0,p) - end - numparams = length(p) - numindvar = length(u0) - J = Matrix{eltype(u0)}(undef,numindvar,numindvar) - pJ = Matrix{eltype(u0)}(undef,numindvar,numparams) # number of funcs size - - # compute gradients of objective - if dg !== nothing - pgpu = nothing - pgpp = nothing - pgpu_config = nothing - pgpp_config = nothing - if dg isa Tuple && length(dg) == 2 - dg_val = (similar(u0, numindvar),similar(u0, numparams)) - dg_val[1] .= false - dg_val[2] .= false +function LSSSensitivityFunction(sensealg, f, analytic, jac, jac_prototype, sparsity, + paramjac, u0, + alg, p, f_cache, mm, + colorvec, tspan, g, dg) + uf = DiffEqBase.UJacobianWrapper(f, tspan[1], p) + pf = DiffEqBase.ParamJacobianWrapper(f, tspan[1], copy(u0)) + + if DiffEqBase.has_jac(f) + jac_config = nothing else - dg_val = similar(u0, numindvar) # number of funcs size - dg_val .= false + jac_config = build_jac_config(sensealg, uf, u0) end - else - pgpu = UGradientWrapper(g,tspan[1],p) # ∂g∂u - pgpp = ParamGradientWrapper(g,tspan[1],u0) #∂g∂p - pgpu_config = build_grad_config(sensealg,pgpu,u0,tspan[1]) - pgpp_config = build_grad_config(sensealg,pgpp,p,tspan[1]) - dg_val = (similar(u0, numindvar),similar(u0, numparams)) - dg_val[1] .= false - dg_val[2] .= false - end - - LSSSensitivityFunction{isinplace(f),typeof(f),typeof(analytic), - typeof(jac),typeof(jac_prototype),typeof(sparsity), - typeof(paramjac), - typeof(uf), - typeof(pf),typeof(jac_config), - typeof(paramjac_config),typeof(alg), - typeof(f_cache), - typeof(J),typeof(pJ),typeof(mm),typeof(f.colorvec), - typeof(pgpu),typeof(pgpp),typeof(pgpu_config),typeof(pgpp_config),typeof(dg_val)}( - f,analytic,jac,jac_prototype,sparsity,paramjac,uf,pf,J,pJ, - jac_config,paramjac_config,alg, - numparams,numindvar,f_cache,mm,colorvec, - pgpu,pgpp,pgpu_config,pgpp_config,dg_val) -end + if DiffEqBase.has_paramjac(f) + paramjac_config = nothing + else + paramjac_config = build_param_jac_config(sensealg, pf, u0, p) + end + numparams = length(p) + numindvar = length(u0) + J = Matrix{eltype(u0)}(undef, numindvar, numindvar) + pJ = Matrix{eltype(u0)}(undef, numindvar, numparams) # number of funcs size + + # compute gradients of objective + if dg !== nothing + pgpu = nothing + pgpp = nothing + pgpu_config = nothing + pgpp_config = nothing + if dg isa Tuple && length(dg) == 2 + dg_val = (similar(u0, numindvar), similar(u0, numparams)) + dg_val[1] .= false + dg_val[2] .= false + else + dg_val = similar(u0, numindvar) # number of funcs size + dg_val .= false + end + else + pgpu = UGradientWrapper(g, tspan[1], p) # ∂g∂u + pgpp = ParamGradientWrapper(g, tspan[1], u0) #∂g∂p + pgpu_config = build_grad_config(sensealg, pgpu, u0, tspan[1]) + pgpp_config = build_grad_config(sensealg, pgpp, p, tspan[1]) + dg_val = (similar(u0, numindvar), similar(u0, numparams)) + dg_val[1] .= false + dg_val[2] .= false + end -struct ForwardLSSProblem{A,C,solType,dtType,umidType,dudtType,SType,Ftype,bType,ηType,wType,vType,windowType, - ΔtType,G0,G,DG,resType} - sensealg::A - diffcache::C - sol::solType - dt::dtType - umid::umidType - dudt::dudtType - S::SType - F::Ftype - b::bType - η::ηType - w::wType - v::vType - window::windowType - Δt::ΔtType - Nt::Int - g0::G0 - g::G - dg::DG - res::resType + LSSSensitivityFunction{isinplace(f), typeof(f), typeof(analytic), + typeof(jac), typeof(jac_prototype), typeof(sparsity), + typeof(paramjac), + typeof(uf), + typeof(pf), typeof(jac_config), + typeof(paramjac_config), typeof(alg), + typeof(f_cache), + typeof(J), typeof(pJ), typeof(mm), typeof(f.colorvec), + typeof(pgpu), typeof(pgpp), typeof(pgpu_config), + typeof(pgpp_config), typeof(dg_val)}(f, analytic, jac, + jac_prototype, sparsity, + paramjac, uf, pf, J, pJ, + jac_config, paramjac_config, + alg, + numparams, numindvar, + f_cache, mm, colorvec, + pgpu, pgpp, pgpu_config, + pgpp_config, dg_val) end +struct ForwardLSSProblem{A, C, solType, dtType, umidType, dudtType, SType, Ftype, bType, + ηType, wType, vType, windowType, + ΔtType, G0, G, DG, resType} + sensealg::A + diffcache::C + sol::solType + dt::dtType + umid::umidType + dudt::dudtType + S::SType + F::Ftype + b::bType + η::ηType + w::wType + v::vType + window::windowType + Δt::ΔtType + Nt::Int + g0::G0 + g::G + dg::DG + res::resType +end function ForwardLSSProblem(sol, sensealg::ForwardLSS; - t=nothing, dg_discrete = nothing, dg_continuous = nothing, - kwargs...) - - @unpack f, p, u0, tspan = sol.prob - @unpack g = sensealg - - isinplace = DiffEqBase.isinplace(f) - - # some shadowing sensealgs require knowledge of g - check_for_g(sensealg,g) - - p === nothing && error("You must have parameters to use parameter sensitivity calculations!") - !(sol.u isa AbstractVector) && error("`u` has to be an AbstractVector.") - - # assert that all ts are hit if concrete solve interface/discrete costs are used - if t !== nothing - @assert sol.t == t - dg = dg_discrete - else - dg = dg_continuous - end - - sense = LSSSensitivityFunction(sensealg,f,f.analytic,f.jac, - f.jac_prototype,f.sparsity,f.paramjac, - u0,sensealg, - p,similar(u0),f.mass_matrix, - f.colorvec, - tspan,g,dg) - - @unpack numparams, numindvar = sense - Nt = length(sol.t) - Ndt = Nt-one(Nt) - - # pre-allocate variables - dt = similar(sol.t, Ndt) - umid = Matrix{eltype(u0)}(undef,numindvar,Ndt) - dudt = Matrix{eltype(u0)}(undef,numindvar,Ndt) - # compute their values - discretize_ref_trajectory!(dt, umid, dudt, sol, Ndt) - - S = LSSSchur(dt,u0,numindvar,Nt,Ndt,sensealg.LSSregularizer) - - if sensealg.LSSregularizer isa TimeDilation - η = similar(dt,Ndt) - window = nothing - g0 = g(u0,p,tspan[1]) - else - η = nothing - window = similar(dt,Nt) - g0 = nothing - end - - b = Matrix{eltype(u0)}(undef,numindvar*Ndt,numparams) - w = similar(dt,numindvar*Ndt) - v = similar(dt,numindvar*Nt) - - Δt = tspan[2] - tspan[1] - wB!(S,Δt,Nt,numindvar,dt) - wE!(S,Δt,dt,sensealg.LSSregularizer) - B!(S,dt,umid,sense,sensealg) - E!(S,dudt,sensealg.LSSregularizer) - - F = SchurLU(S) - - res = similar(u0, numparams) - - ForwardLSSProblem{typeof(sensealg),typeof(sense),typeof(sol),typeof(dt), - typeof(umid),typeof(dudt), - typeof(S),typeof(F),typeof(b),typeof(η),typeof(w),typeof(v),typeof(window),typeof(Δt), - typeof(g0),typeof(g),typeof(dg),typeof(res)}(sensealg,sense,sol,dt,umid,dudt,S,F,b,η,w,v, - window,Δt,Nt,g0,g,dg,res) + t = nothing, dg_discrete = nothing, dg_continuous = nothing, + kwargs...) + @unpack f, p, u0, tspan = sol.prob + @unpack g = sensealg + + isinplace = DiffEqBase.isinplace(f) + + # some shadowing sensealgs require knowledge of g + check_for_g(sensealg, g) + + p === nothing && + error("You must have parameters to use parameter sensitivity calculations!") + !(sol.u isa AbstractVector) && error("`u` has to be an AbstractVector.") + + # assert that all ts are hit if concrete solve interface/discrete costs are used + if t !== nothing + @assert sol.t == t + dg = dg_discrete + else + dg = dg_continuous + end + + sense = LSSSensitivityFunction(sensealg, f, f.analytic, f.jac, + f.jac_prototype, f.sparsity, f.paramjac, + u0, sensealg, + p, similar(u0), f.mass_matrix, + f.colorvec, + tspan, g, dg) + + @unpack numparams, numindvar = sense + Nt = length(sol.t) + Ndt = Nt - one(Nt) + + # pre-allocate variables + dt = similar(sol.t, Ndt) + umid = Matrix{eltype(u0)}(undef, numindvar, Ndt) + dudt = Matrix{eltype(u0)}(undef, numindvar, Ndt) + # compute their values + discretize_ref_trajectory!(dt, umid, dudt, sol, Ndt) + + S = LSSSchur(dt, u0, numindvar, Nt, Ndt, sensealg.LSSregularizer) + + if sensealg.LSSregularizer isa TimeDilation + η = similar(dt, Ndt) + window = nothing + g0 = g(u0, p, tspan[1]) + else + η = nothing + window = similar(dt, Nt) + g0 = nothing + end + + b = Matrix{eltype(u0)}(undef, numindvar * Ndt, numparams) + w = similar(dt, numindvar * Ndt) + v = similar(dt, numindvar * Nt) + + Δt = tspan[2] - tspan[1] + wB!(S, Δt, Nt, numindvar, dt) + wE!(S, Δt, dt, sensealg.LSSregularizer) + B!(S, dt, umid, sense, sensealg) + E!(S, dudt, sensealg.LSSregularizer) + + F = SchurLU(S) + + res = similar(u0, numparams) + + ForwardLSSProblem{typeof(sensealg), typeof(sense), typeof(sol), typeof(dt), + typeof(umid), typeof(dudt), + typeof(S), typeof(F), typeof(b), typeof(η), typeof(w), typeof(v), + typeof(window), typeof(Δt), + typeof(g0), typeof(g), typeof(dg), typeof(res)}(sensealg, sense, sol, + dt, umid, dudt, S, F, + b, η, w, v, + window, Δt, Nt, g0, g, + dg, res) end -function LSSSchur(dt,u0,numindvar,Nt,Ndt,LSSregularizer::TimeDilation) - wBinv = similar(dt,numindvar*Nt) - wEinv = similar(dt,Ndt) - E = Matrix{eltype(u0)}(undef,numindvar*Ndt,Ndt) - B = Matrix{eltype(u0)}(undef,numindvar*Ndt,numindvar*Nt) +function LSSSchur(dt, u0, numindvar, Nt, Ndt, LSSregularizer::TimeDilation) + wBinv = similar(dt, numindvar * Nt) + wEinv = similar(dt, Ndt) + E = Matrix{eltype(u0)}(undef, numindvar * Ndt, Ndt) + B = Matrix{eltype(u0)}(undef, numindvar * Ndt, numindvar * Nt) - LSSSchur(wBinv,wEinv,B,E) + LSSSchur(wBinv, wEinv, B, E) end -function LSSSchur(dt,u0,numindvar,Nt,Ndt,LSSregularizer::AbstractCosWindowing) - wBinv = similar(dt,numindvar*Nt) - wEinv = nothing - E = nothing - B = Matrix{eltype(u0)}(undef,numindvar*Ndt,numindvar*Nt) +function LSSSchur(dt, u0, numindvar, Nt, Ndt, LSSregularizer::AbstractCosWindowing) + wBinv = similar(dt, numindvar * Nt) + wEinv = nothing + E = nothing + B = Matrix{eltype(u0)}(undef, numindvar * Ndt, numindvar * Nt) - LSSSchur(wBinv,wEinv,B,E) + LSSSchur(wBinv, wEinv, B, E) end # compute discretized reference trajectory function discretize_ref_trajectory!(dt, umid, dudt, sol, Ndt) - for i=1:Ndt - tr = sol.t[i+1] - tl = sol.t[i] - ur = sol.u[i+1] - ul = sol.u[i] - dt[i] = tr-tl - copyto!((@view umid[:,i]), (ur + ul)/2) - copyto!((@view dudt[:,i]), (ur - ul)/dt[i]) - end - return nothing + for i in 1:Ndt + tr = sol.t[i + 1] + tl = sol.t[i] + ur = sol.u[i + 1] + ul = sol.u[i] + dt[i] = tr - tl + copyto!((@view umid[:, i]), (ur + ul) / 2) + copyto!((@view dudt[:, i]), (ur - ul) / dt[i]) + end + return nothing end -function wB!(S::LSSSchur,Δt,Nt,numindvar,dt) - @unpack wBinv = S - fill!(wBinv, one(Δt)) - dim = numindvar * Nt - tmp = @view wBinv[1:numindvar] - tmp ./= dt[1] - tmp = @view wBinv[dim-2:end] - tmp ./= dt[end] - for indx = 2:Nt-1 - tmp = @view wBinv[(indx-1)*numindvar+1:indx*numindvar] - tmp ./= (dt[indx]+dt[indx-1]) - end - - wBinv .*= 2*Δt - return nothing +function wB!(S::LSSSchur, Δt, Nt, numindvar, dt) + @unpack wBinv = S + fill!(wBinv, one(Δt)) + dim = numindvar * Nt + tmp = @view wBinv[1:numindvar] + tmp ./= dt[1] + tmp = @view wBinv[(dim - 2):end] + tmp ./= dt[end] + for indx in 2:(Nt - 1) + tmp = @view wBinv[((indx - 1) * numindvar + 1):(indx * numindvar)] + tmp ./= (dt[indx] + dt[indx - 1]) + end + + wBinv .*= 2 * Δt + return nothing end -wE!(S::LSSSchur,Δt,dt,LSSregularizer::AbstractCosWindowing) = nothing +wE!(S::LSSSchur, Δt, dt, LSSregularizer::AbstractCosWindowing) = nothing -function wE!(S::LSSSchur,Δt,dt,LSSregularizer::TimeDilation) - @unpack wEinv = S - @unpack alpha = LSSregularizer - @. wEinv = Δt/(alpha^2*dt) - return nothing +function wE!(S::LSSSchur, Δt, dt, LSSregularizer::TimeDilation) + @unpack wEinv = S + @unpack alpha = LSSregularizer + @. wEinv = Δt / (alpha^2 * dt) + return nothing end -function B!(S::LSSSchur,dt,umid,sense,sensealg) - @unpack B = S - @unpack f,J,uf,numindvar,f_cache,jac_config = sense +function B!(S::LSSSchur, dt, umid, sense, sensealg) + @unpack B = S + @unpack f, J, uf, numindvar, f_cache, jac_config = sense - fill!(B, zero(eltype(J))) + fill!(B, zero(eltype(J))) - for (i,u) in enumerate(eachcol(umid)) - if DiffEqBase.has_jac(f) - f.jac(J,u,uf.p,uf.t) # Calculate the Jacobian into J - else - jacobian!(J, uf, u, f_cache, sensealg, jac_config) + for (i, u) in enumerate(eachcol(umid)) + if DiffEqBase.has_jac(f) + f.jac(J, u, uf.p, uf.t) # Calculate the Jacobian into J + else + jacobian!(J, uf, u, f_cache, sensealg, jac_config) + end + B0 = @view B[((i - 1) * numindvar + 1):(i * numindvar), + (i * numindvar + 1):((i + 1) * numindvar)] + B1 = @view B[((i - 1) * numindvar + 1):(i * numindvar), + ((i - 1) * numindvar + 1):(i * numindvar)] + B0 .+= I / dt[i] - J / 2 + B1 .+= -I / dt[i] - J / 2 end - B0 = @view B[(i-1)*numindvar+1:i*numindvar,i*numindvar+1:(i+1)*numindvar] - B1 = @view B[(i-1)*numindvar+1:i*numindvar,(i-1)*numindvar+1:i*numindvar] - B0 .+= I/dt[i] - J/2 - B1 .+= -I/dt[i] -J/2 - end - return nothing + return nothing end -E!(S::LSSSchur,dudt,LSSregularizer::AbstractCosWindowing) = nothing +E!(S::LSSSchur, dudt, LSSregularizer::AbstractCosWindowing) = nothing -function E!(S::LSSSchur,dudt,LSSregularizer::TimeDilation) - @unpack E = S - numindvar, Ndt = size(dudt) - for i=1:Ndt - tmp = @view E[(i-1)*numindvar+1:i*numindvar,i] - copyto!(tmp, (@view dudt[:,i])) - end - return nothing +function E!(S::LSSSchur, dudt, LSSregularizer::TimeDilation) + @unpack E = S + numindvar, Ndt = size(dudt) + for i in 1:Ndt + tmp = @view E[((i - 1) * numindvar + 1):(i * numindvar), i] + copyto!(tmp, (@view dudt[:, i])) + end + return nothing end # compute Schur function SchurLU(S::LSSSchur) - @unpack B, E, wBinv, wEinv = S - Smat = B*Diagonal(wBinv)*B' - (wEinv !== nothing) && (Smat .+= E*Diagonal(wEinv)*E') - F = lu!(Smat) - return F + @unpack B, E, wBinv, wEinv = S + Smat = B * Diagonal(wBinv) * B' + (wEinv !== nothing) && (Smat .+= E * Diagonal(wEinv) * E') + F = lu!(Smat) + return F end function b!(b, prob::ForwardLSSProblem) - @unpack diffcache, umid, sensealg = prob - @unpack f, f_cache, pJ, pf, paramjac_config, uf, numindvar = diffcache + @unpack diffcache, umid, sensealg = prob + @unpack f, f_cache, pJ, pf, paramjac_config, uf, numindvar = diffcache - for (i,u) in enumerate(eachcol(umid)) - if DiffEqBase.has_paramjac(f) - f.paramjac(pJ, u, uf.p, pf.t) - else - pf.u = u - jacobian!(pJ, pf, uf.p, f_cache, sensealg, paramjac_config) + for (i, u) in enumerate(eachcol(umid)) + if DiffEqBase.has_paramjac(f) + f.paramjac(pJ, u, uf.p, pf.t) + else + pf.u = u + jacobian!(pJ, pf, uf.p, f_cache, sensealg, paramjac_config) + end + tmp = @view b[((i - 1) * numindvar + 1):(i * numindvar), :] + copyto!(tmp, pJ) end - tmp = @view b[(i-1)*numindvar+1:i*numindvar,:] - copyto!(tmp, pJ) - end - return nothing + return nothing end -function shadow_forward(prob::ForwardLSSProblem; sensealg=prob.sensealg) - shadow_forward(prob,sensealg,sensealg.LSSregularizer) +function shadow_forward(prob::ForwardLSSProblem; sensealg = prob.sensealg) + shadow_forward(prob, sensealg, sensealg.LSSregularizer) end -function shadow_forward(prob::ForwardLSSProblem,sensealg::ForwardLSS,LSSregularizer::TimeDilation) - @unpack sol, S, F, window, Δt, diffcache, b, w, v, η, res, g, g0, dg, umid = prob - @unpack wBinv, wEinv, B, E = S - @unpack dg_val, numparams, numindvar, uf = diffcache - @unpack t0skip, t1skip = LSSregularizer - - n0 = searchsortedfirst(sol.t, sol.t[1]+t0skip) - n1 = searchsortedfirst(sol.t, sol.t[end]-t1skip) +function shadow_forward(prob::ForwardLSSProblem, sensealg::ForwardLSS, + LSSregularizer::TimeDilation) + @unpack sol, S, F, window, Δt, diffcache, b, w, v, η, res, g, g0, dg, umid = prob + @unpack wBinv, wEinv, B, E = S + @unpack dg_val, numparams, numindvar, uf = diffcache + @unpack t0skip, t1skip = LSSregularizer - b!(b,prob) + n0 = searchsortedfirst(sol.t, sol.t[1] + t0skip) + n1 = searchsortedfirst(sol.t, sol.t[end] - t1skip) - ures = @view sol.u[n0:n1] - umidres = @view umid[:,n0:n1-1] + b!(b, prob) - # reset - res .*=false + ures = @view sol.u[n0:n1] + umidres = @view umid[:, n0:(n1 - 1)] - for i=1:numparams - #running average - g0 *= false - bpar = @view b[:,i] - w .= F\bpar - v .= Diagonal(wBinv)*(B'*w) - η .= Diagonal(wEinv)*(E'*w) + # reset + res .*= false - ηres = @view η[n0:n1-1] + for i in 1:numparams + #running average + g0 *= false + bpar = @view b[:, i] + w .= F \ bpar + v .= Diagonal(wBinv) * (B' * w) + η .= Diagonal(wEinv) * (E' * w) + + ηres = @view η[n0:(n1 - 1)] + + for (j, u) in enumerate(ures) + vtmp = @view v[((n0 + j - 2) * numindvar + 1):((n0 + j - 1) * numindvar)] + # final gradient result for ith parameter + accumulate_cost!(dg, u, uf.p, uf.t, sensealg, diffcache, n0 + j - 1) + + if dg_val isa Tuple + res[i] += dot(dg_val[1], vtmp) + res[i] += dg_val[2][i] + else + res[i] += dot(dg_val, vtmp) + end + end + # mean value + res[i] = res[i] / (n1 - n0 + 1) + + for (j, u) in enumerate(eachcol(umidres)) + # compute objective + gtmp = g(u, uf.p, nothing) + g0 += gtmp + res[i] -= ηres[j] * gtmp / (n1 - n0) + end + res[i] = res[i] + sum(ηres) * g0 / (n1 - n0)^2 + end + return res +end - for (j, u) in enumerate(ures) - vtmp = @view v[(n0+j-2)*numindvar+1:(n0+j-1)*numindvar] - # final gradient result for ith parameter - accumulate_cost!(dg, u, uf.p, uf.t, sensealg, diffcache, n0+j-1) +function shadow_forward(prob::ForwardLSSProblem, sensealg::ForwardLSS, + LSSregularizer::CosWindowing) + @unpack sol, S, F, window, Δt, diffcache, b, w, v, dg, res = prob + @unpack wBinv, B = S + @unpack dg_val, numparams, numindvar, uf = diffcache - if dg_val isa Tuple - res[i] += dot(dg_val[1], vtmp) - res[i] += dg_val[2][i] - else - res[i] += dot(dg_val, vtmp) - end + b!(b, prob) - end - # mean value - res[i] = res[i]/(n1-n0+1) - - for (j,u) in enumerate(eachcol(umidres)) - # compute objective - gtmp = g(u,uf.p,nothing) - g0 += gtmp - res[i] -= ηres[j]*gtmp/(n1-n0) - end - res[i] = res[i] + sum(ηres)*g0/(n1-n0)^2 + # windowing (cos) + @. window = (sol.t - sol.t[1]) * convert(eltype(Δt), 2 * pi / Δt) + @. window = one(eltype(window)) - cos(window) + window ./= sum(window) - end - return res -end + res .*= false -function shadow_forward(prob::ForwardLSSProblem,sensealg::ForwardLSS,LSSregularizer::CosWindowing) - @unpack sol, S, F, window, Δt, diffcache, b, w, v, dg, res = prob - @unpack wBinv, B = S - @unpack dg_val, numparams, numindvar, uf = diffcache - - b!(b,prob) - - # windowing (cos) - @. window = (sol.t-sol.t[1])*convert(eltype(Δt),2*pi/Δt) - @. window = one(eltype(window)) - cos(window) - window ./= sum(window) - - res .*= false - - for i=1:numparams - bpar = @view b[:,i] - w .= F\bpar - v .= Diagonal(wBinv)*(B'*w) - for (j,u) in enumerate(sol.u) - vtmp = @view v[(j-1)*numindvar+1:j*numindvar] - # final gradient result for ith parameter - accumulate_cost!(dg, u, uf.p, uf.t, sensealg, diffcache, j) - if dg_val isa Tuple - res[i] += dot(dg_val[1], vtmp) * window[j] - res[i] += dg_val[2][i] * window[j] - else - res[i] += dot(dg_val, vtmp) * window[j] - end + for i in 1:numparams + bpar = @view b[:, i] + w .= F \ bpar + v .= Diagonal(wBinv) * (B' * w) + for (j, u) in enumerate(sol.u) + vtmp = @view v[((j - 1) * numindvar + 1):(j * numindvar)] + # final gradient result for ith parameter + accumulate_cost!(dg, u, uf.p, uf.t, sensealg, diffcache, j) + if dg_val isa Tuple + res[i] += dot(dg_val[1], vtmp) * window[j] + res[i] += dg_val[2][i] * window[j] + else + res[i] += dot(dg_val, vtmp) * window[j] + end + end end - end - return res + return res end -function shadow_forward(prob::ForwardLSSProblem,sensealg::ForwardLSS,LSSregularizer::Cos2Windowing) +function shadow_forward(prob::ForwardLSSProblem, sensealg::ForwardLSS, + LSSregularizer::Cos2Windowing) @unpack sol, S, F, window, Δt, diffcache, b, w, v, dg, res = prob @unpack wBinv, B = S @unpack dg_val, numparams, numindvar, uf = diffcache - b!(b,prob) + b!(b, prob) res .*= false # windowing cos2 - @. window = (sol.t-sol.t[1])*convert(eltype(Δt),2*pi/Δt) + @. window = (sol.t - sol.t[1]) * convert(eltype(Δt), 2 * pi / Δt) @. window = (one(eltype(window)) - cos(window))^2 window ./= sum(window) - for i=1:numparams - bpar = @view b[:,i] - w .= F\bpar - v .= Diagonal(wBinv)*(B'*w) - for (j, u) in enumerate(sol.u) - vtmp = @view v[(j-1)*numindvar+1:j*numindvar] - # final gradient result for ith parameter - accumulate_cost!(dg, u, uf.p, uf.t, sensealg, diffcache, j) - if dg_val isa Tuple - res[i] += dot(dg_val[1], vtmp) * window[j] - res[i] += dg_val[2][i] * window[j] - else - res[i] += dot(dg_val, vtmp) * window[j] + for i in 1:numparams + bpar = @view b[:, i] + w .= F \ bpar + v .= Diagonal(wBinv) * (B' * w) + for (j, u) in enumerate(sol.u) + vtmp = @view v[((j - 1) * numindvar + 1):(j * numindvar)] + # final gradient result for ith parameter + accumulate_cost!(dg, u, uf.p, uf.t, sensealg, diffcache, j) + if dg_val isa Tuple + res[i] += dot(dg_val[1], vtmp) * window[j] + res[i] += dg_val[2][i] * window[j] + else + res[i] += dot(dg_val, vtmp) * window[j] + end end - end end return res end function accumulate_cost!(dg, u, p, t, sensealg::ForwardLSS, diffcache, indx) - @unpack dg_val, pgpu, pgpu_config, pgpp, pgpp_config, uf = diffcache + @unpack dg_val, pgpu, pgpu_config, pgpp, pgpp_config, uf = diffcache - if dg === nothing - if dg_val isa Tuple - SciMLSensitivity.gradient!(dg_val[1], pgpu, u, sensealg, pgpu_config) - SciMLSensitivity.gradient!(dg_val[2], pgpp, p, sensealg, pgpp_config) - else - SciMLSensitivity.gradient!(dg_val, pgpu, u, sensealg, pgpu_config) - end - else - if dg_val isa Tuple - dg[1](dg_val[1], u, p, nothing, indx) # indx = n0 + j - 1 for LSSregularizer and j for windowing - dg[2](dg_val[2], u, p, nothing, indx) + if dg === nothing + if dg_val isa Tuple + SciMLSensitivity.gradient!(dg_val[1], pgpu, u, sensealg, pgpu_config) + SciMLSensitivity.gradient!(dg_val[2], pgpp, p, sensealg, pgpp_config) + else + SciMLSensitivity.gradient!(dg_val, pgpu, u, sensealg, pgpu_config) + end else - dg(dg_val, u, p, nothing, indx) + if dg_val isa Tuple + dg[1](dg_val[1], u, p, nothing, indx) # indx = n0 + j - 1 for LSSregularizer and j for windowing + dg[2](dg_val[2], u, p, nothing, indx) + else + dg(dg_val, u, p, nothing, indx) + end end - end - return nothing + return nothing end -struct AdjointLSSProblem{A,C,solType,dtType,umidType,dudtType,SType,FType,hType,bType,wType, - ΔtType,G0,G,DG,resType} - sensealg::A - diffcache::C - sol::solType - dt::dtType - umid::umidType - dudt::dudtType - S::SType - F::FType - h::hType - b::bType - wa::wType - Δt::ΔtType - Nt::Int - g0::G0 - g::G - dg::DG - res::resType +struct AdjointLSSProblem{A, C, solType, dtType, umidType, dudtType, SType, FType, hType, + bType, wType, + ΔtType, G0, G, DG, resType} + sensealg::A + diffcache::C + sol::solType + dt::dtType + umid::umidType + dudt::dudtType + S::SType + F::FType + h::hType + b::bType + wa::wType + Δt::ΔtType + Nt::Int + g0::G0 + g::G + dg::DG + res::resType end - function AdjointLSSProblem(sol, sensealg::AdjointLSS; - t=nothing, dg_discrete = nothing, dg_continuous = nothing, - kwargs...) - - @unpack f, p, u0, tspan = sol.prob - @unpack g = sensealg - - isinplace = DiffEqBase.isinplace(f) - - # some shadowing sensealgs require knowledge of g - check_for_g(sensealg,g) - - p === nothing && error("You must have parameters to use parameter sensitivity calculations!") - !(sol.u isa AbstractVector) && error("`u` has to be an AbstractVector.") - - # assert that all ts are hit if concrete solve interface/discrete costs are used - if t !== nothing - @assert sol.t == t - dg = dg_discrete - else - dg = dg_continuous - end - - sense = LSSSensitivityFunction(sensealg,f,f.analytic,f.jac, - f.jac_prototype,f.sparsity,f.paramjac, - u0,sensealg, - p,similar(u0),f.mass_matrix, - f.colorvec, - tspan,g,dg) - - @unpack numparams, numindvar = sense - Nt = length(sol.t) - Ndt = Nt-one(Nt) - - # pre-allocate variables - dt = similar(sol.t, Ndt) - umid = Matrix{eltype(u0)}(undef,numindvar,Ndt) - dudt = Matrix{eltype(u0)}(undef,numindvar,Ndt) - # compute their values - discretize_ref_trajectory!(dt, umid, dudt, sol, Ndt) - - S = LSSSchur(dt,u0,numindvar,Nt,Ndt,sensealg.LSSregularizer) - - if sensealg.LSSregularizer isa TimeDilation - g0 = g(u0,p,tspan[1]) - else - g0 = nothing - end - - b = Vector{eltype(u0)}(undef,numindvar*Ndt) - h = Vector{eltype(u0)}(undef,Ndt) - wa = similar(dt,numindvar*Ndt) - - Δt = tspan[2] - tspan[1] - wB!(S,Δt,Nt,numindvar,dt) - wE!(S,Δt,dt,sensealg.LSSregularizer) - - B!(S,dt,umid,sense,sensealg) - E!(S,dudt,sensealg.LSSregularizer) - F = SchurLU(S) - wBcorrect!(S,sol,g,Nt,sense,sensealg,dg) - - h!(h,g0,g,umid,p,S.wEinv) - - res = similar(u0, numparams) - - AdjointLSSProblem{typeof(sensealg),typeof(sense),typeof(sol),typeof(dt), - typeof(umid),typeof(dudt), - typeof(S),typeof(F),typeof(h),typeof(b),typeof(wa),typeof(Δt), - typeof(g0),typeof(g),typeof(dg),typeof(res)}(sensealg,sense,sol,dt,umid,dudt,S,F,h,b,wa, - Δt,Nt,g0,g,dg,res) -end + t = nothing, dg_discrete = nothing, dg_continuous = nothing, + kwargs...) + @unpack f, p, u0, tspan = sol.prob + @unpack g = sensealg -function h!(h,g0,g,u,p,wEinv) + isinplace = DiffEqBase.isinplace(f) - for (j,uj) in enumerate(eachcol(u)) - # compute objective - h[j] = g(uj,p,nothing) - end - h .= -(h .- mean(h)) / (size(u)[2]) + # some shadowing sensealgs require knowledge of g + check_for_g(sensealg, g) - @. h = wEinv*h + p === nothing && + error("You must have parameters to use parameter sensitivity calculations!") + !(sol.u isa AbstractVector) && error("`u` has to be an AbstractVector.") - return nothing -end + # assert that all ts are hit if concrete solve interface/discrete costs are used + if t !== nothing + @assert sol.t == t + dg = dg_discrete + else + dg = dg_continuous + end -function wBcorrect!(S,sol,g,Nt,sense,sensealg,dg) - @unpack dg_val, pgpu, pgpu_config, numparams, numindvar, uf = sense - @unpack wBinv = S + sense = LSSSensitivityFunction(sensealg, f, f.analytic, f.jac, + f.jac_prototype, f.sparsity, f.paramjac, + u0, sensealg, + p, similar(u0), f.mass_matrix, + f.colorvec, + tspan, g, dg) - for (i,u) in enumerate(sol.u) - _wBinv = @view wBinv[(i-1)*numindvar+1:i*numindvar] - if dg === nothing - if dg_val isa Tuple - SciMLSensitivity.gradient!(dg_val[1], pgpu, u, sensealg, pgpu_config) - @. _wBinv = _wBinv*dg_val[1]/Nt - else - SciMLSensitivity.gradient!(dg_val, pgpu, u, sensealg, pgpu_config) - @. _wBinv = _wBinv*dg_val/Nt - end + @unpack numparams, numindvar = sense + Nt = length(sol.t) + Ndt = Nt - one(Nt) + + # pre-allocate variables + dt = similar(sol.t, Ndt) + umid = Matrix{eltype(u0)}(undef, numindvar, Ndt) + dudt = Matrix{eltype(u0)}(undef, numindvar, Ndt) + # compute their values + discretize_ref_trajectory!(dt, umid, dudt, sol, Ndt) + + S = LSSSchur(dt, u0, numindvar, Nt, Ndt, sensealg.LSSregularizer) + + if sensealg.LSSregularizer isa TimeDilation + g0 = g(u0, p, tspan[1]) else - if dg_val isa Tuple - dg[1](dg_val[1],u,uf.p,nothing,i) - @. _wBinv = _wBinv*dg_val[1]/Nt - else - dg(dg_val,u,uf.p,nothing,i) - @. _wBinv = _wBinv*dg_val/Nt - end + g0 = nothing end - end - return nothing + + b = Vector{eltype(u0)}(undef, numindvar * Ndt) + h = Vector{eltype(u0)}(undef, Ndt) + wa = similar(dt, numindvar * Ndt) + + Δt = tspan[2] - tspan[1] + wB!(S, Δt, Nt, numindvar, dt) + wE!(S, Δt, dt, sensealg.LSSregularizer) + + B!(S, dt, umid, sense, sensealg) + E!(S, dudt, sensealg.LSSregularizer) + F = SchurLU(S) + wBcorrect!(S, sol, g, Nt, sense, sensealg, dg) + + h!(h, g0, g, umid, p, S.wEinv) + + res = similar(u0, numparams) + + AdjointLSSProblem{typeof(sensealg), typeof(sense), typeof(sol), typeof(dt), + typeof(umid), typeof(dudt), + typeof(S), typeof(F), typeof(h), typeof(b), typeof(wa), typeof(Δt), + typeof(g0), typeof(g), typeof(dg), typeof(res)}(sensealg, sense, sol, + dt, umid, dudt, S, F, + h, b, wa, + Δt, Nt, g0, g, dg, + res) end -function shadow_adjoint(prob::AdjointLSSProblem; sensealg=prob.sensealg) - shadow_adjoint(prob,sensealg,sensealg.LSSregularizer) +function h!(h, g0, g, u, p, wEinv) + for (j, uj) in enumerate(eachcol(u)) + # compute objective + h[j] = g(uj, p, nothing) + end + h .= -(h .- mean(h)) / (size(u)[2]) + + @. h = wEinv * h + + return nothing end -function shadow_adjoint(prob::AdjointLSSProblem,sensealg::AdjointLSS,LSSregularizer::TimeDilation) - @unpack sol, S, F, Δt, diffcache, h, b, wa, res, g, g0, dg, umid = prob - @unpack wBinv, B, E = S - @unpack dg_val, pgpp, pgpp_config, numparams, numindvar, uf, f, f_cache, pJ, pf, paramjac_config = diffcache - @unpack t0skip, t1skip = LSSregularizer - - b .= E*h + B*wBinv - wa .= F\b - - n0 = searchsortedfirst(sol.t, sol.t[1]+t0skip) - n1 = searchsortedfirst(sol.t, sol.t[end]-t1skip) - - umidres = @view umid[:,n0:n1-1] - wares = @view wa[(n0-1)*numindvar+1:(n1-1)*numindvar] - - # reset - res .*= false - - if dg_val isa Tuple - for (j,u) in enumerate(eachcol(umidres)) - if dg === nothing - SciMLSensitivity.gradient!(dg_val[2], pgpp, uf.p, sensealg, pgpp_config) - @. res += dg_val[2] - else - dg[2](dg_val[2],u,uf.p,nothing,n0+j-1) - @. res += dg_val[2] - end +function wBcorrect!(S, sol, g, Nt, sense, sensealg, dg) + @unpack dg_val, pgpu, pgpu_config, numparams, numindvar, uf = sense + @unpack wBinv = S + + for (i, u) in enumerate(sol.u) + _wBinv = @view wBinv[((i - 1) * numindvar + 1):(i * numindvar)] + if dg === nothing + if dg_val isa Tuple + SciMLSensitivity.gradient!(dg_val[1], pgpu, u, sensealg, pgpu_config) + @. _wBinv = _wBinv * dg_val[1] / Nt + else + SciMLSensitivity.gradient!(dg_val, pgpu, u, sensealg, pgpu_config) + @. _wBinv = _wBinv * dg_val / Nt + end + else + if dg_val isa Tuple + dg[1](dg_val[1], u, uf.p, nothing, i) + @. _wBinv = _wBinv * dg_val[1] / Nt + else + dg(dg_val, u, uf.p, nothing, i) + @. _wBinv = _wBinv * dg_val / Nt + end + end end - res ./= (size(umidres)[2]) - end + return nothing +end - for (j,u) in enumerate(eachcol(umidres)) - _wares = @view wares[(j-1)*numindvar+1:j*numindvar] - if DiffEqBase.has_paramjac(f) - f.paramjac(pJ, u, uf.p, pf.t) - else - pf.u = u - jacobian!(pJ, pf, uf.p, f_cache, sensealg, paramjac_config) +function shadow_adjoint(prob::AdjointLSSProblem; sensealg = prob.sensealg) + shadow_adjoint(prob, sensealg, sensealg.LSSregularizer) +end + +function shadow_adjoint(prob::AdjointLSSProblem, sensealg::AdjointLSS, + LSSregularizer::TimeDilation) + @unpack sol, S, F, Δt, diffcache, h, b, wa, res, g, g0, dg, umid = prob + @unpack wBinv, B, E = S + @unpack dg_val, pgpp, pgpp_config, numparams, numindvar, uf, f, f_cache, pJ, pf, paramjac_config = diffcache + @unpack t0skip, t1skip = LSSregularizer + + b .= E * h + B * wBinv + wa .= F \ b + + n0 = searchsortedfirst(sol.t, sol.t[1] + t0skip) + n1 = searchsortedfirst(sol.t, sol.t[end] - t1skip) + + umidres = @view umid[:, n0:(n1 - 1)] + wares = @view wa[((n0 - 1) * numindvar + 1):((n1 - 1) * numindvar)] + + # reset + res .*= false + + if dg_val isa Tuple + for (j, u) in enumerate(eachcol(umidres)) + if dg === nothing + SciMLSensitivity.gradient!(dg_val[2], pgpp, uf.p, sensealg, pgpp_config) + @. res += dg_val[2] + else + dg[2](dg_val[2], u, uf.p, nothing, n0 + j - 1) + @. res += dg_val[2] + end + end + res ./= (size(umidres)[2]) end - res .+= pJ'*_wares - end + for (j, u) in enumerate(eachcol(umidres)) + _wares = @view wares[((j - 1) * numindvar + 1):(j * numindvar)] + if DiffEqBase.has_paramjac(f) + f.paramjac(pJ, u, uf.p, pf.t) + else + pf.u = u + jacobian!(pJ, pf, uf.p, f_cache, sensealg, paramjac_config) + end + + res .+= pJ' * _wares + end - return res + return res end -check_for_g(sensealg::Union{ForwardLSS,AdjointLSS},g)=((sensealg.LSSregularizer isa TimeDilation && g===nothing) && error("Time dilation needs explicit knowledge of g. Either pass `g` as a kwarg to `ForwardLSS(g=g)` or `AdjointLSS(g=g)` or use ForwardLSS/AdjointLSS with windowing.")) +function check_for_g(sensealg::Union{ForwardLSS, AdjointLSS}, g) + ((sensealg.LSSregularizer isa TimeDilation && g === nothing) && + error("Time dilation needs explicit knowledge of g. Either pass `g` as a kwarg to `ForwardLSS(g=g)` or `AdjointLSS(g=g)` or use ForwardLSS/AdjointLSS with windowing.")) +end diff --git a/src/nilsas.jl b/src/nilsas.jl index 74deb48fc..e6244f780 100644 --- a/src/nilsas.jl +++ b/src/nilsas.jl @@ -1,439 +1,457 @@ -struct NILSASSensitivityFunction{iip,NILSS,ASF,Mtype} <: DiffEqBase.AbstractODEFunction{iip} - nilss::NILSS - S::ASF # Adjoint sensitivity function - M::Mtype - discrete::Bool +struct NILSASSensitivityFunction{iip, NILSS, ASF, Mtype} <: + DiffEqBase.AbstractODEFunction{iip} + nilss::NILSS + S::ASF # Adjoint sensitivity function + M::Mtype + discrete::Bool end -struct QuadratureCache{A1,A2,A3,A4,A5} - dwv::A1 - dwf::A1 - dwfs::A2 - dvf::A3 - dvfs::A4 - dJs::A4 - C::A5 - R::A5 - b::A1 +struct QuadratureCache{A1, A2, A3, A4, A5} + dwv::A1 + dwf::A1 + dwfs::A2 + dvf::A3 + dvfs::A4 + dJs::A4 + C::A5 + R::A5 + b::A1 end function QuadratureCache(u0, M, nseg, numparams) - dwv = Array{eltype(u0)}(undef, M, nseg) - dwf = Array{eltype(u0)}(undef, M, nseg) - dwfs = Array{eltype(u0)}(undef, numparams*M, nseg) - dvf = Array{eltype(u0)}(undef, 1, nseg) - dvfs = Array{eltype(u0)}(undef, numparams, nseg) - dJs = Array{eltype(u0)}(undef, numparams, nseg) - C = Array{eltype(u0)}(undef, M, M, nseg) - R = Array{eltype(u0)}(undef, M, M, nseg) - b = Array{eltype(u0)}(undef, M, nseg) - - QuadratureCache{typeof(dwv),typeof(dwfs),typeof(dvf),typeof(dvfs),typeof(C)}(dwv,dwf,dwfs,dvf,dvfs,dJs,C,R,b) + dwv = Array{eltype(u0)}(undef, M, nseg) + dwf = Array{eltype(u0)}(undef, M, nseg) + dwfs = Array{eltype(u0)}(undef, numparams * M, nseg) + dvf = Array{eltype(u0)}(undef, 1, nseg) + dvfs = Array{eltype(u0)}(undef, numparams, nseg) + dJs = Array{eltype(u0)}(undef, numparams, nseg) + C = Array{eltype(u0)}(undef, M, M, nseg) + R = Array{eltype(u0)}(undef, M, M, nseg) + b = Array{eltype(u0)}(undef, M, nseg) + + QuadratureCache{typeof(dwv), typeof(dwfs), typeof(dvf), typeof(dvfs), typeof(C)}(dwv, + dwf, + dwfs, + dvf, + dvfs, + dJs, C, + R, b) end -struct NILSASProblem{A,NILSS,Aprob,Qcache,solType,z0Type,tType,DG1,DG2,G,T} - sensealg::A - nilss::NILSS # diffcache - adjoint_prob::Aprob - quadcache::Qcache - sol::solType - z0::z0Type - t::tType - dg_discrete::DG1 - dg_continuous::DG2 - g::G - T_seg::T - dtsave::T +struct NILSASProblem{A, NILSS, Aprob, Qcache, solType, z0Type, tType, DG1, DG2, G, T} + sensealg::A + nilss::NILSS # diffcache + adjoint_prob::Aprob + quadcache::Qcache + sol::solType + z0::z0Type + t::tType + dg_discrete::DG1 + dg_continuous::DG2 + g::G + T_seg::T + dtsave::T end - function NILSASProblem(sol, sensealg::NILSAS, - t=nothing, dg_discrete = nothing, dg_continuous = nothing; kwargs...) - - @unpack f, p, u0, tspan = sol.prob - @unpack nseg, nstep, rng, adjoint_sensealg, M, g = sensealg #number of segments on time interval, number of steps saved on each segment - - numindvar = length(u0) - numparams = length(p) - - # some shadowing sensealgs require knowledge of g - check_for_g(sensealg,g) - - # sensealg choice - adjoint_sensealg === nothing && (adjoint_sensealg = automatic_sensealg_choice(sol.prob,u0,p,false)) - - p === nothing && error("You must have parameters to use parameter sensitivity calculations!") - !(u0 isa AbstractVector) && error("`u` has to be an AbstractVector.") - - nstep <= 1 && error("At least the start and the end point of each segment must be stored. Please use `nstep >=2`.") - - !(u0 isa AbstractVector) && error("`u` has to be an AbstractVector.") - - # segmentation: determine length of segmentation and spacing between saved points - T_seg = (tspan[2]-tspan[1])/nseg # length of each segment - dtsave = T_seg/(nstep-1) - - # homogenous + inhomogenous adjoint sensitivity problem - # assign initial values to y, vstar, w - y = copy(sol.u[end]) - z0 = terminate_conditions(adjoint_sensealg,rng,f,y,p,tspan[2],numindvar,numparams,M) - nilss = NILSSSensitivityFunction(sensealg,f,u0,p,tspan,g,dg_continuous) - tspan = (tspan[2] - T_seg, tspan[2]) - checkpoints = tspan[1]:dtsave:tspan[2] + t = nothing, dg_discrete = nothing, dg_continuous = nothing; + kwargs...) + @unpack f, p, u0, tspan = sol.prob + @unpack nseg, nstep, rng, adjoint_sensealg, M, g = sensealg #number of segments on time interval, number of steps saved on each segment + + numindvar = length(u0) + numparams = length(p) + + # some shadowing sensealgs require knowledge of g + check_for_g(sensealg, g) + + # sensealg choice + adjoint_sensealg === nothing && + (adjoint_sensealg = automatic_sensealg_choice(sol.prob, u0, p, false)) + + p === nothing && + error("You must have parameters to use parameter sensitivity calculations!") + !(u0 isa AbstractVector) && error("`u` has to be an AbstractVector.") + + nstep <= 1 && + error("At least the start and the end point of each segment must be stored. Please use `nstep >=2`.") + + !(u0 isa AbstractVector) && error("`u` has to be an AbstractVector.") + + # segmentation: determine length of segmentation and spacing between saved points + T_seg = (tspan[2] - tspan[1]) / nseg # length of each segment + dtsave = T_seg / (nstep - 1) + + # homogenous + inhomogenous adjoint sensitivity problem + # assign initial values to y, vstar, w + y = copy(sol.u[end]) + z0 = terminate_conditions(adjoint_sensealg, rng, f, y, p, tspan[2], numindvar, + numparams, M) + nilss = NILSSSensitivityFunction(sensealg, f, u0, p, tspan, g, dg_continuous) + tspan = (tspan[2] - T_seg, tspan[2]) + checkpoints = tspan[1]:dtsave:tspan[2] + + adjoint_prob = ODEAdjointProblem(sol, adjoint_sensealg, t, dg_discrete, dg_continuous, + g; + checkpoints = checkpoints, + z0 = z0, M = M, nilss = nilss, tspan = tspan, + kwargs...) + + # pre-allocate variables for integration Eq.(23) in NILSAS paper. + quadcache = QuadratureCache(u0, M, nseg, numparams) + + NILSASProblem{typeof(sensealg), typeof(nilss), typeof(adjoint_prob), typeof(quadcache), + typeof(sol), typeof(z0), typeof(t), typeof(dg_discrete), + typeof(dg_continuous), typeof(g), typeof(T_seg)}(sensealg, nilss, + adjoint_prob, quadcache, + sol, deepcopy(z0), t, + dg_discrete, + dg_continuous, g, T_seg, + dtsave) +end - adjoint_prob = ODEAdjointProblem(sol,adjoint_sensealg,t,dg_discrete,dg_continuous,g; - checkpoints=checkpoints, - z0=z0, M=M, nilss=nilss, tspan=tspan, kwargs...) +function terminate_conditions(alg::BacksolveAdjoint, rng, f, y, p, t, numindvar, numparams, + M) + if isinplace(f) + f_unit = zero(y) + f(f_unit, y, p, t) + normalize!(f_unit) + else + f_unit = f(y, p, t) + normalize!(f_unit) + end - # pre-allocate variables for integration Eq.(23) in NILSAS paper. - quadcache = QuadratureCache(u0, M, nseg, numparams) + if M > 1 + W = rand(rng, numindvar, M - 1) + W .-= (f_unit' * W) .* f_unit + w, _ = qr(W) + _w = @view w[:, 1:(M - 1)] + W = hcat(_w, f_unit) + else + W = f_unit + end + vst = zeros(numindvar) - NILSASProblem{typeof(sensealg),typeof(nilss),typeof(adjoint_prob),typeof(quadcache), - typeof(sol),typeof(z0),typeof(t),typeof(dg_discrete),typeof(dg_continuous),typeof(g),typeof(T_seg)}(sensealg, nilss, - adjoint_prob, quadcache, sol, deepcopy(z0), t, dg_discrete, dg_continuous, g, T_seg, dtsave) -end + # quadratures + C = zeros(M, M) + dwv = zeros(M) + dwf = zeros(M) + dvf = zeros(1) + dJs = zeros(numparams) -function terminate_conditions(alg::BacksolveAdjoint,rng,f,y,p,t,numindvar,numparams,M) - if isinplace(f) - f_unit = zero(y) - f(f_unit,y,p,t) - normalize!(f_unit) - else - f_unit = f(y,p,t) - normalize!(f_unit) - end - - if M>1 - W = rand(rng,numindvar,M-1) - W .-= (f_unit'*W) .* f_unit - w, _ = qr(W) - _w = @view w[:,1:M-1] - W = hcat(_w, f_unit) - else - W = f_unit - end - vst = zeros(numindvar) - - # quadratures - C = zeros(M,M) - dwv = zeros(M) - dwf = zeros(M) - dvf = zeros(1) - dJs = zeros(numparams) - - return ArrayPartition([vst;vec(W)],zeros(numparams*(M+1)),y,C,dwv,dwf,dvf,dJs) + return ArrayPartition([vst; vec(W)], zeros(numparams * (M + 1)), y, C, dwv, dwf, dvf, + dJs) end +function split_states(du, u, t, NS::NILSASSensitivityFunction, j; update = true) + @unpack nilss, S = NS + @unpack numindvar, numparams = nilss -function split_states(du,u,t,NS::NILSASSensitivityFunction,j;update=true) - @unpack nilss, S = NS - @unpack numindvar,numparams = nilss - - indx1 = (j-1)*(numindvar) + 1 - indx2 = indx1 + (numindvar-1) - indx3 = (j-1)*(numparams) + 1 - indx4 = indx3 + (numparams-1) - - λ = @view u.x[1][indx1:indx2] - grad = @view u.x[2][indx3:indx4] - _y = u.x[3] + indx1 = (j - 1) * (numindvar) + 1 + indx2 = indx1 + (numindvar - 1) + indx3 = (j - 1) * (numparams) + 1 + indx4 = indx3 + (numparams - 1) - # like ODE/Drift term and scalar noise - dλ = @view du.x[1][indx1:indx2] - dgrad = @view du.x[2][indx3:indx4] - dy = du.x[3] + λ = @view u.x[1][indx1:indx2] + grad = @view u.x[2][indx3:indx4] + _y = u.x[3] + # like ODE/Drift term and scalar noise + dλ = @view du.x[1][indx1:indx2] + dgrad = @view du.x[2][indx3:indx4] + dy = du.x[3] - λ,grad,_y,dλ,dgrad,dy + λ, grad, _y, dλ, dgrad, dy end -function split_quadratures(du,u,t,NS::NILSASSensitivityFunction;update=true) - @unpack nilss, S = NS - @unpack numindvar,numparams = nilss +function split_quadratures(du, u, t, NS::NILSASSensitivityFunction; update = true) + @unpack nilss, S = NS + @unpack numindvar, numparams = nilss - C = u.x[4] - dwv = u.x[5] - dwf = u.x[6] - dvf = u.x[7] - dJs = u.x[8] + C = u.x[4] + dwv = u.x[5] + dwf = u.x[6] + dvf = u.x[7] + dJs = u.x[8] - dC = du.x[4] - ddwv = du.x[5] - ddwf = du.x[6] - ddvf = du.x[7] - ddJs = du.x[8] + dC = du.x[4] + ddwv = du.x[5] + ddwf = du.x[6] + ddvf = du.x[7] + ddJs = du.x[8] - dC,ddwv,ddwf,ddvf,ddJs, C,dwv,dwf,dvf,dJs + dC, ddwv, ddwf, ddvf, ddJs, C, dwv, dwf, dvf, dJs end +function (NS::NILSASSensitivityFunction)(du, u, p, t) + @unpack nilss, S, M = NS + @unpack f, dg, dg_val, pgpu, pgpu_config, pgpp, pgpp_config, numparams, numindvar, alg = nilss + @unpack y, discrete = S -function (NS::NILSASSensitivityFunction)(du,u,p,t) - @unpack nilss, S, M = NS - @unpack f, dg, dg_val, pgpu, pgpu_config, pgpp, pgpp_config, numparams, numindvar, alg = nilss - @unpack y, discrete = S + λ, _, _y, dλ, dgrad, dy = split_states(du, u, t, NS, 1) + copyto!(vec(y), _y) - λ,_,_y,dλ,dgrad,dy = split_states(du,u,t,NS,1) - copyto!(vec(y), _y) - - # compute gradient of objective wrt. state - if !discrete - accumulate_cost!(dg, y, p, t, nilss) - end - - # loop over all adjoint states - for j=1:M+1 - λ,_,_,dλ,dgrad,dy = split_states(du,u,t,NS,j) - vecjacobian!(dλ, y, λ, p, t, S, dgrad=dgrad, dy=dy) - dλ .*= -1 - dgrad .*= -1 + # compute gradient of objective wrt. state + if !discrete + accumulate_cost!(dg, y, p, t, nilss) + end - if j==1 - # j = 1 is the inhomogenous adjoint solution - if !discrete - if dg_val isa Tuple - dλ .-= vec(dg_val[1]) - else - dλ .-= vec(dg_val) + # loop over all adjoint states + for j in 1:(M + 1) + λ, _, _, dλ, dgrad, dy = split_states(du, u, t, NS, j) + vecjacobian!(dλ, y, λ, p, t, S, dgrad = dgrad, dy = dy) + dλ .*= -1 + dgrad .*= -1 + + if j == 1 + # j = 1 is the inhomogenous adjoint solution + if !discrete + if dg_val isa Tuple + dλ .-= vec(dg_val[1]) + else + dλ .-= vec(dg_val) + end + end end - end end - end - - # quadratures - dC,ddwv,ddwf,ddvf,ddJs, _,_,_,_,_ = split_quadratures(du,u,t,NS) - # j = 1 is the inhomogenous adjoint solution - λv,_,_,_,_,dy = split_states(du,u,t,NS,1) - ddvf .= -dot(λv,dy) - for j=1:M - λ,_,_,_,_,_ = split_states(du,u,t,NS,j+1) - ddwf[j] = -dot(λ,dy) - ddwv[j] = -dot(λ,λv) - for i=j+1:M - λ2,_,_,_,_,_ = split_states(du,u,t,NS,i+1) - dC[j,i] = -dot(λ,λ2) - dC[i,j] = dC[j,i] + + # quadratures + dC, ddwv, ddwf, ddvf, ddJs, _, _, _, _, _ = split_quadratures(du, u, t, NS) + # j = 1 is the inhomogenous adjoint solution + λv, _, _, _, _, dy = split_states(du, u, t, NS, 1) + ddvf .= -dot(λv, dy) + for j in 1:M + λ, _, _, _, _, _ = split_states(du, u, t, NS, j + 1) + ddwf[j] = -dot(λ, dy) + ddwv[j] = -dot(λ, λv) + for i in (j + 1):M + λ2, _, _, _, _, _ = split_states(du, u, t, NS, i + 1) + dC[j, i] = -dot(λ, λ2) + dC[i, j] = dC[j, i] + end + dC[j, j] = -dot(λ, λ) end - dC[j,j] = -dot(λ,λ) - end - if dg_val isa Tuple && !discrete - ddJs .= -vec(dg_val[2]) - end + if dg_val isa Tuple && !discrete + ddJs .= -vec(dg_val[2]) + end - return nothing + return nothing end - function accumulate_cost!(dg, y, p, t, nilss::NILSSSensitivityFunction) - @unpack dg_val, pgpu, pgpu_config, pgpp, pgpp_config, alg = nilss + @unpack dg_val, pgpu, pgpu_config, pgpp, pgpp_config, alg = nilss - if dg===nothing - if dg_val isa Tuple - SciMLSensitivity.gradient!(dg_val[1],pgpu,y,alg,pgpu_config) - SciMLSensitivity.gradient!(dg_val[2],pgpp,y,alg,pgpp_config) - else - SciMLSensitivity.gradient!(dg_val,pgpu,y,alg,pgpu_config) - end - else - if dg_val isa Tuple - dg[1](dg_val[1],y,p,t) - dg[2](dg_val[2],y,p,t) + if dg === nothing + if dg_val isa Tuple + SciMLSensitivity.gradient!(dg_val[1], pgpu, y, alg, pgpu_config) + SciMLSensitivity.gradient!(dg_val[2], pgpp, y, alg, pgpp_config) + else + SciMLSensitivity.gradient!(dg_val, pgpu, y, alg, pgpu_config) + end else - dg(dg_val,y,p,t) + if dg_val isa Tuple + dg[1](dg_val[1], y, p, t) + dg[2](dg_val[2], y, p, t) + else + dg(dg_val, y, p, t) + end end - end - return nothing + return nothing end -function adjoint_sense(prob::NILSASProblem,nilsas::NILSAS,alg; kwargs...) - - @unpack M, nseg, nstep, adjoint_sensealg = nilsas - @unpack sol, nilss, z0, t, dg_discrete, dg_continuous, g, T_seg, dtsave, adjoint_prob = prob - @unpack u0, tspan = adjoint_prob - - copyto!(z0,u0) - - @assert haskey(adjoint_prob.kwargs, :callback) - # get loss callback - cb = adjoint_prob.kwargs[:callback] - - # adjoint sensitivities on segments - for iseg=nseg:-1:1 - t1 = tspan[1]-(nseg-iseg+1)*T_seg - t2 = tspan[1]-(nseg-iseg)*T_seg - checkpoints=t1:dtsave:t2 - _prob = ODEAdjointProblem(sol, adjoint_sensealg, t, dg_discrete, dg_continuous, g; - checkpoints=checkpoints, z0=z0, M=M, nilss=nilss, tspan=(t1, t2), kwargs...) - _sol = solve(_prob,alg; save_everystep=false,save_start=false,saveat=eltype(sol[1])[], - dt = dtsave, - tstops=checkpoints, - callback = cb, - kwargs...) - - # renormalize at interfaces and store quadratures - # update sense problem - renormalize!(prob, _sol, z0, iseg) - end - return nothing +function adjoint_sense(prob::NILSASProblem, nilsas::NILSAS, alg; kwargs...) + @unpack M, nseg, nstep, adjoint_sensealg = nilsas + @unpack sol, nilss, z0, t, dg_discrete, dg_continuous, g, T_seg, dtsave, adjoint_prob = prob + @unpack u0, tspan = adjoint_prob + + copyto!(z0, u0) + + @assert haskey(adjoint_prob.kwargs, :callback) + # get loss callback + cb = adjoint_prob.kwargs[:callback] + + # adjoint sensitivities on segments + for iseg in nseg:-1:1 + t1 = tspan[1] - (nseg - iseg + 1) * T_seg + t2 = tspan[1] - (nseg - iseg) * T_seg + checkpoints = t1:dtsave:t2 + _prob = ODEAdjointProblem(sol, adjoint_sensealg, t, dg_discrete, dg_continuous, g; + checkpoints = checkpoints, z0 = z0, M = M, nilss = nilss, + tspan = (t1, t2), kwargs...) + _sol = solve(_prob, alg; save_everystep = false, save_start = false, + saveat = eltype(sol[1])[], + dt = dtsave, + tstops = checkpoints, + callback = cb, + kwargs...) + + # renormalize at interfaces and store quadratures + # update sense problem + renormalize!(prob, _sol, z0, iseg) + end + return nothing end function renormalize!(prob::NILSASProblem, sol, z0, iseg) - @unpack quadcache, nilss, sensealg = prob - @unpack M = sensealg - @unpack numparams, numindvar = nilss - @unpack R,b = quadcache + @unpack quadcache, nilss, sensealg = prob + @unpack M = sensealg + @unpack numparams, numindvar = nilss + @unpack R, b = quadcache - x = sol.u[end].x - # vstar_right (inhomogenous adjoint on the rhs of the interface) - vstar = @view x[1][1:numindvar] - # homogenous adjoint on the rhs of the interface - W = @view x[1][numindvar+1:end] - W = reshape(W, numindvar, M) + x = sol.u[end].x + # vstar_right (inhomogenous adjoint on the rhs of the interface) + vstar = @view x[1][1:numindvar] + # homogenous adjoint on the rhs of the interface + W = @view x[1][(numindvar + 1):end] + W = reshape(W, numindvar, M) - Q_, R_ = qr(W) - Q = @view Q_[:,1:M] - b_ = (Q'*vstar) + Q_, R_ = qr(W) + Q = @view Q_[:, 1:M] + b_ = (Q' * vstar) - # store R and b to solve NILSAS problem - copyto!( (@view R[:,:,iseg]), R_) - copyto!( (@view b[:,iseg]), b_) + # store R and b to solve NILSAS problem + copyto!((@view R[:, :, iseg]), R_) + copyto!((@view b[:, iseg]), b_) - # store quadrature values - store_quad(quadcache, x, numparams, iseg) + # store quadrature values + store_quad(quadcache, x, numparams, iseg) - # reset z0 - reset!(z0, numindvar, vstar, b_, Q) + # reset z0 + reset!(z0, numindvar, vstar, b_, Q) - return nothing + return nothing end function store_quad(quadcache, x, numparams, iseg) - @unpack dwv,dwf,dwfs,dvf,dvfs,dJs,C = quadcache - - grad_vfs = @view x[2][1:numparams] - copyto!( (@view dvfs[:,iseg]), grad_vfs) - - grad_wfs = @view x[2][numparams+1:end] - copyto!( (@view dwfs[:,iseg]), grad_wfs) - - # C_i = x[4] - copyto!( (@view C[:,:,iseg]), x[4]) - # dwv_i = x[5] - copyto!( (@view dwv[:,iseg]), x[5]) - # dwf_i = x[6] - copyto!( (@view dwf[:,iseg]), x[6]) - # dvf_i = x[7] - copyto!( (@view dvf[:,iseg]), x[7]) - # dJs_i = x[8] - copyto!( (@view dJs[:,iseg]), x[8]) - return nothing + @unpack dwv, dwf, dwfs, dvf, dvfs, dJs, C = quadcache + + grad_vfs = @view x[2][1:numparams] + copyto!((@view dvfs[:, iseg]), grad_vfs) + + grad_wfs = @view x[2][(numparams + 1):end] + copyto!((@view dwfs[:, iseg]), grad_wfs) + + # C_i = x[4] + copyto!((@view C[:, :, iseg]), x[4]) + # dwv_i = x[5] + copyto!((@view dwv[:, iseg]), x[5]) + # dwf_i = x[6] + copyto!((@view dwf[:, iseg]), x[6]) + # dvf_i = x[7] + copyto!((@view dvf[:, iseg]), x[7]) + # dJs_i = x[8] + copyto!((@view dJs[:, iseg]), x[8]) + return nothing end function reset!(z0, numindvar, vstar, b, Q) - # modify z0 - x0 = z0.x - - # vstar_left - v = @view x0[1][1:numindvar] - v .= vstar - vec(b'*Q') - - # W_left (homogenous adjoint on lhs of the interface) - w = @view x0[1][numindvar+1:end] - w .= vec(Q) - - # reset all other values t0 zero - x0[2] .*= false - x0[4] .*= false - x0[5] .*= false - x0[6] .*= false - x0[7] .*= false - x0[8] .*= false - return nothing + # modify z0 + x0 = z0.x + + # vstar_left + v = @view x0[1][1:numindvar] + v .= vstar - vec(b' * Q') + + # W_left (homogenous adjoint on lhs of the interface) + w = @view x0[1][(numindvar + 1):end] + w .= vec(Q) + + # reset all other values t0 zero + x0[2] .*= false + x0[4] .*= false + x0[5] .*= false + x0[6] .*= false + x0[7] .*= false + x0[8] .*= false + return nothing end function nilsas_min(cache::QuadratureCache) - @unpack dwv,dwf,dvf,C,R,b = cache - - # Construct Schur complement of the Lagrange multiplier method of the NILSAS problem. - - # see description in Appendix A of Nilsas paper. - # M= # unstable CLVs, K = # segments - M, K = size(dwv) - - # construct Cinv - # Cinv is a block diagonal matrix - Cinv = zeros(eltype(C), M*K, M*K) - - for i=1:K - Ci = @view C[:, :, i] - _Cinv = @view Cinv[(i-1)*M+1:i*M, (i-1)*M+1:i*M] - Ciinv = inv(Ci) - copyto!(_Cinv,Ciinv) - end - - # construct B, also very sparse if K >> M - B = zeros(eltype(C), M*K-M+1, M*K) - - for i=1:K - if i> M + B = zeros(eltype(C), M * K - M + 1, M * K) + + for i in 1:K + if i < K + # off diagonal Rs + _B = @view B[((i - 1) * M + 1):(i * M), (i * M + 1):((i + 1) * M)] + _R = @view R[:, :, i + 1] + copyto!(_B, _R) + _B .*= -1 + + # diagonal ones + for j in 1:M + B[(i - 1) * M + j, (i - 1) * M + j] = one(eltype(R)) + end + end + # last row + dwfi = dwf[:, i] + _B = @view B[end, ((i - 1) * M + 1):(i * M)] + copyto!(_B, dwfi) + end + + # construct d + d = vec(dwv) - # construct b - _b = [b[M+1:end]; -sum(dvf)] + # construct b + _b = [b[(M + 1):end]; -sum(dvf)] - # compute Lagrange multiplier - λ = (-B*Cinv*B') \ (B*Cinv*d + _b) + # compute Lagrange multiplier + λ = (-B * Cinv * B') \ (B * Cinv * d + _b) - # return a - return reshape(-Cinv*(B'*λ + d), M, K) + # return a + return reshape(-Cinv * (B' * λ + d), M, K) end -function shadow_adjoint(prob::NILSASProblem,alg; sensealg=prob.sensealg, kwargs...) - shadow_adjoint(prob,sensealg,alg; kwargs...) +function shadow_adjoint(prob::NILSASProblem, alg; sensealg = prob.sensealg, kwargs...) + shadow_adjoint(prob, sensealg, alg; kwargs...) end -function shadow_adjoint(prob::NILSASProblem,sensealg::NILSAS,alg; kwargs...) +function shadow_adjoint(prob::NILSASProblem, sensealg::NILSAS, alg; kwargs...) - # compute adjoint sensitivities - adjoint_sense(prob,sensealg,alg; kwargs...) + # compute adjoint sensitivities + adjoint_sense(prob, sensealg, alg; kwargs...) - # compute NILSAS problem on multiple segments - a = nilsas_min(prob.quadcache) + # compute NILSAS problem on multiple segments + a = nilsas_min(prob.quadcache) - # compute gradient, Eq. (28) -- second part to avoid explicit construction of vbar - @unpack M, nseg = sensealg - @unpack dvfs, dJs, dwfs = prob.quadcache + # compute gradient, Eq. (28) -- second part to avoid explicit construction of vbar + @unpack M, nseg = sensealg + @unpack dvfs, dJs, dwfs = prob.quadcache - res = vec(sum(dvfs,dims=2)) + vec(sum(dJs,dims=2)) - NP = length(res) # number of parameters + res = vec(sum(dvfs, dims = 2)) + vec(sum(dJs, dims = 2)) + NP = length(res) # number of parameters - # loop over segments - for (i,ai) in enumerate(eachcol(a)) - dwfsi = @view dwfs[:,i] - dwfsi = reshape(dwfsi,NP,M) - res .+= dwfsi*ai - end + # loop over segments + for (i, ai) in enumerate(eachcol(a)) + dwfsi = @view dwfs[:, i] + dwfsi = reshape(dwfsi, NP, M) + res .+= dwfsi * ai + end - return res/(nseg*prob.T_seg) + return res / (nseg * prob.T_seg) end -check_for_g(sensealg::NILSAS,g) = (g===nothing && error("To use NILSAS, g must be passed as a kwarg to `NILSAS(g=g)`.")) +function check_for_g(sensealg::NILSAS, g) + (g === nothing && error("To use NILSAS, g must be passed as a kwarg to `NILSAS(g=g)`.")) +end diff --git a/src/nilss.jl b/src/nilss.jl index 0cbe27567..296fab10a 100644 --- a/src/nilss.jl +++ b/src/nilss.jl @@ -1,566 +1,584 @@ -struct NILSSSensitivityFunction{iip,F,Alg, - PGPU,PGPP,CONFU,CONGP,DGVAL,DG,jType,RefType} <: DiffEqBase.AbstractODEFunction{iip} - f::F - alg::Alg - numparams::Int - numindvar::Int - pgpu::PGPU - pgpp::PGPP - pgpu_config::CONFU - pgpp_config::CONGP - dg_val::DGVAL - dg::DG - jevery::jType # if concrete_solve interface for discrete cost functions is used - cur_time::RefType +struct NILSSSensitivityFunction{iip, F, Alg, + PGPU, PGPP, CONFU, CONGP, DGVAL, DG, jType, RefType} <: + DiffEqBase.AbstractODEFunction{iip} + f::F + alg::Alg + numparams::Int + numindvar::Int + pgpu::PGPU + pgpp::PGPP + pgpu_config::CONFU + pgpp_config::CONGP + dg_val::DGVAL + dg::DG + jevery::jType # if concrete_solve interface for discrete cost functions is used + cur_time::RefType end -function NILSSSensitivityFunction(sensealg,f,u0,p,tspan,g,dg,jevery=nothing,cur_time=nothing) - - numparams = length(p) - numindvar = length(u0) - - # compute gradients of objective - if dg !== nothing - pgpu = nothing - pgpp = nothing - pgpu_config = nothing - pgpp_config = nothing - if dg isa Tuple && length(dg) == 2 - dg_val = (similar(u0, numindvar),similar(u0, numparams)) - dg_val[1] .= false - dg_val[2] .= false +function NILSSSensitivityFunction(sensealg, f, u0, p, tspan, g, dg, jevery = nothing, + cur_time = nothing) + numparams = length(p) + numindvar = length(u0) + + # compute gradients of objective + if dg !== nothing + pgpu = nothing + pgpp = nothing + pgpu_config = nothing + pgpp_config = nothing + if dg isa Tuple && length(dg) == 2 + dg_val = (similar(u0, numindvar), similar(u0, numparams)) + dg_val[1] .= false + dg_val[2] .= false + else + dg_val = similar(u0, numindvar) # number of funcs size + dg_val .= false + end else - dg_val = similar(u0, numindvar) # number of funcs size - dg_val .= false + pgpu = UGradientWrapper(g, tspan[1], p) # ∂g∂u + pgpp = ParamGradientWrapper(g, tspan[1], u0) #∂g∂p + pgpu_config = build_grad_config(sensealg, pgpu, u0, tspan[1]) + pgpp_config = build_grad_config(sensealg, pgpp, u0, tspan[1]) + dg_val = (similar(u0, numindvar), similar(u0, numparams)) + dg_val[1] .= false + dg_val[2] .= false end - else - pgpu = UGradientWrapper(g,tspan[1],p) # ∂g∂u - pgpp = ParamGradientWrapper(g,tspan[1],u0) #∂g∂p - pgpu_config = build_grad_config(sensealg,pgpu,u0,tspan[1]) - pgpp_config = build_grad_config(sensealg,pgpp,u0,tspan[1]) - dg_val = (similar(u0, numindvar),similar(u0, numparams)) - dg_val[1] .= false - dg_val[2] .= false - end - - NILSSSensitivityFunction{isinplace(f),typeof(f),typeof(sensealg), - typeof(pgpu),typeof(pgpp),typeof(pgpu_config),typeof(pgpp_config),typeof(dg_val),typeof(dg),typeof(jevery),typeof(cur_time)}( - f,sensealg,numparams,numindvar,pgpu,pgpp,pgpu_config,pgpp_config,dg_val,dg,jevery,cur_time) -end - -struct NILSSProblem{A,CacheType,FSprob,probType,u0Type,vstar0Type,w0Type, - TType,dtType,gType,yType,vstarType, - wType,RType,bType,weightType,CType,dType,BType,aType,vType,xiType, - G,DG,resType} - sensealg::A - diffcache::CacheType - forward_prob::FSprob - prob::probType - u0::u0Type - vstar0::vstar0Type - w0::w0Type - nus::Int - T_seg::TType - dtsave::dtType - gsave::gType - y::yType - dudt::yType - dgdu::yType - vstar::vstarType - vstar_perp::vstarType - w::wType - w_perp::wType - R::RType - b::bType - weight::weightType - Cinv::CType - d::dType - B::BType - a::aType - v::vType - v_perp::vType - ξ::xiType - g::G - dg::DG - res::resType + NILSSSensitivityFunction{isinplace(f), typeof(f), typeof(sensealg), + typeof(pgpu), typeof(pgpp), typeof(pgpu_config), + typeof(pgpp_config), typeof(dg_val), typeof(dg), + typeof(jevery), typeof(cur_time)}(f, sensealg, numparams, + numindvar, pgpu, pgpp, + pgpu_config, pgpp_config, + dg_val, dg, jevery, cur_time) end - -function NILSSProblem(prob, sensealg::NILSS; - t=nothing, dg_discrete = nothing, dg_continuous = nothing, - kwargs...) - - @unpack f, p, u0, tspan = prob - @unpack nseg, nstep, nus, rng, g = sensealg #number of segments on time interval, number of steps saved on each segment - - numindvar = length(u0) - numparams = length(p) - - # some shadowing sensealgs require knowledge of g - check_for_g(sensealg,g) - - # integer dimension of the unstable subspace - if nus === nothing - nus = numindvar - one(numindvar) - end - (nus >= numindvar) && error("`nus` must be smaller than `numindvar`.") - - isinplace = DiffEqBase.isinplace(f) - - p === nothing && error("You must have parameters to use parameter sensitivity calculations!") - !(u0 isa AbstractVector) && error("`u` has to be an AbstractVector.") - - # segmentation: determine length of segmentation and spacing between saved points - T_seg = (tspan[2]-tspan[1])/nseg # length of each segment - dtsave = T_seg/(nstep-1) - - # assert that dtsave is chosen such that all ts are hit if concrete solve interface/discrete costs are used - if t!==nothing - @assert t isa StepRangeLen - dt_ts = step(t) - @assert dt_ts >= dtsave - @assert T_seg >= dt_ts - jevery = Int(dt_ts/dtsave) # will throw an inexact error if dt_ts is not a multiple of dtsave. (could be more sophisticated) - cur_time = Ref(1) - dg = dg_discrete - else - jevery = nothing - cur_time = nothing - dg = dg_continuous - end - - # inhomogenous forward sensitivity problem - chunk_size = determine_chunksize(numparams,sensealg) - autodiff = alg_autodiff(sensealg) - difftype = diff_type(sensealg) - autojacvec = sensealg.autojacvec - # homogenous + inhomogenous forward sensitivity problems - forward_prob = ODEForwardSensitivityProblem(f,u0,tspan,p,ForwardSensitivity(chunk_size=chunk_size,autodiff=autodiff, - diff_type=difftype,autojacvec=autojacvec);nus=nus, kwargs...) - - sense = NILSSSensitivityFunction(sensealg,f,u0,p,tspan,g,dg,jevery,cur_time) - - # pre-allocate variables - gsave = Matrix{eltype(u0)}(undef, nstep, nseg) - y = Array{eltype(u0)}(undef, numindvar, nstep, nseg) - dudt = similar(y) - dgdu = similar(y) - vstar = Array{eltype(u0)}(undef, numparams, numindvar, nstep, nseg) # generalization for several parameters numindvar*numparams - vstar_perp = similar(vstar) - w = Array{eltype(u0)}(undef, numindvar, nstep, nseg, nus) - w_perp = similar(w) - - # assign initial values to y, v*, w - y[:,1,1] .= u0 - for i=1:numparams - _vstar = @view vstar[i,:,1,1] - copyto!(_vstar, zero(u0)) - end - for ius=1:nus - _w = @view w[:,1,1,ius] - rand!(rng,_w) - normalize!(_w) - end - - vstar0 = zeros(eltype(u0), numindvar*numparams) - w0 = vec(w[:,1,1,:]) - - R = Array{eltype(u0)}(undef, numparams, nseg-1, nus, nus) - b = Array{eltype(u0)}(undef, numparams, (nseg-1)*nus) - - # a weight matrix for integration, 0.5 at interfaces - weight = ones(1,nstep) - weight[1] /= 2 - weight[end] /= 2 - - # Construct Schur complement of the Lagrange multiplier method of the NILSS problem. - # See the paper on FD-NILSS - # find C^-1 - Cinv = Matrix{eltype(u0)}(undef, nseg*nus, nseg*nus) - Cinv .*= false - d = Vector{eltype(u0)}(undef, nseg*nus) - B = Matrix{eltype(u0)}(undef, (nseg-1)*nus, nseg*nus) - B .*= false - - a = Vector{eltype(u0)}(undef, nseg*nus) - v = Array{eltype(u0)}(undef, numindvar, nstep, nseg) - v_perp = similar(v) - - # only need to use last step in each segment - ξ = Matrix{eltype(u0)}(undef, nseg, 2) - - res = similar(u0, numparams) - - NILSSProblem{typeof(sensealg),typeof(sense),typeof(forward_prob),typeof(prob), - typeof(u0), typeof(vstar0), typeof(w0), - typeof(T_seg),typeof(dtsave),typeof(gsave),typeof(y),typeof(vstar),typeof(w),typeof(R), - typeof(b),typeof(weight),typeof(Cinv),typeof(d),typeof(B),typeof(a),typeof(v),typeof(ξ), - typeof(g),typeof(dg),typeof(res)}(sensealg,sense,forward_prob,prob,u0,vstar0,w0, - nus,T_seg,dtsave,gsave,y,dudt,dgdu,vstar,vstar_perp,w,w_perp,R,b,weight,Cinv,d, - B,a,v,v_perp,ξ,g,dg,res) +struct NILSSProblem{A, CacheType, FSprob, probType, u0Type, vstar0Type, w0Type, + TType, dtType, gType, yType, vstarType, + wType, RType, bType, weightType, CType, dType, BType, aType, vType, + xiType, + G, DG, resType} + sensealg::A + diffcache::CacheType + forward_prob::FSprob + prob::probType + u0::u0Type + vstar0::vstar0Type + w0::w0Type + nus::Int + T_seg::TType + dtsave::dtType + gsave::gType + y::yType + dudt::yType + dgdu::yType + vstar::vstarType + vstar_perp::vstarType + w::wType + w_perp::wType + R::RType + b::bType + weight::weightType + Cinv::CType + d::dType + B::BType + a::aType + v::vType + v_perp::vType + ξ::xiType + g::G + dg::DG + res::resType end +function NILSSProblem(prob, sensealg::NILSS; + t = nothing, dg_discrete = nothing, dg_continuous = nothing, + kwargs...) + @unpack f, p, u0, tspan = prob + @unpack nseg, nstep, nus, rng, g = sensealg #number of segments on time interval, number of steps saved on each segment -function (NS::NILSSForwardSensitivityFunction)(du,u,p,t) - @unpack S, nus = NS - y = @view u[1:S.numindvar] # These are the independent variables - dy = @view du[1:S.numindvar] - S.f(dy,y,p,t) # Make the first part be the ODE + numindvar = length(u0) + numparams = length(p) - # Now do sensitivities - # Compute the Jacobian + # some shadowing sensealgs require knowledge of g + check_for_g(sensealg, g) - if !S.isautojacvec - if has_original_jac(S.f) - S.original_jac(S.J,y,p,t) # Calculate the Jacobian into J + # integer dimension of the unstable subspace + if nus === nothing + nus = numindvar - one(numindvar) + end + (nus >= numindvar) && error("`nus` must be smaller than `numindvar`.") + + isinplace = DiffEqBase.isinplace(f) + + p === nothing && + error("You must have parameters to use parameter sensitivity calculations!") + !(u0 isa AbstractVector) && error("`u` has to be an AbstractVector.") + + # segmentation: determine length of segmentation and spacing between saved points + T_seg = (tspan[2] - tspan[1]) / nseg # length of each segment + dtsave = T_seg / (nstep - 1) + + # assert that dtsave is chosen such that all ts are hit if concrete solve interface/discrete costs are used + if t !== nothing + @assert t isa StepRangeLen + dt_ts = step(t) + @assert dt_ts >= dtsave + @assert T_seg >= dt_ts + jevery = Int(dt_ts / dtsave) # will throw an inexact error if dt_ts is not a multiple of dtsave. (could be more sophisticated) + cur_time = Ref(1) + dg = dg_discrete else - S.uf.t = t - jacobian!(S.J, S.uf, y, S.f_cache, S.alg, S.jac_config) + jevery = nothing + cur_time = nothing + dg = dg_continuous end - end - - if DiffEqBase.has_paramjac(S.f) - S.paramjac(S.pJ,y,p,t) # Calculate the parameter Jacobian into pJ - else - S.pf.t = t - S.pf.u .= y - jacobian!(S.pJ, S.pf, p, S.f_cache, S.alg, S.paramjac_config) - end - - # Compute the parameter derivatives - for j=1:nus+1 - for i in eachindex(p) - indx1 = (j-1)*S.numindvar*1 + i*S.numindvar+1 - indx2 = (j-1)*S.numindvar*1 + (i+1)*S.numindvar - Sj = @view u[indx1:indx2] - dp = @view du[indx1:indx2] - if !S.isautojacvec - mul!(dp,S.J,Sj) - else - jacobianvec!(dp, S.uf, y, Sj, S.alg, S.jac_config) - end - if j == nus+1 - # inhomogenous (otherwise homogenous tangent solution) - dp .+= @view S.pJ[:,i] - end + + # inhomogenous forward sensitivity problem + chunk_size = determine_chunksize(numparams, sensealg) + autodiff = alg_autodiff(sensealg) + difftype = diff_type(sensealg) + autojacvec = sensealg.autojacvec + # homogenous + inhomogenous forward sensitivity problems + forward_prob = ODEForwardSensitivityProblem(f, u0, tspan, p, + ForwardSensitivity(chunk_size = chunk_size, + autodiff = autodiff, + diff_type = difftype, + autojacvec = autojacvec); + nus = nus, kwargs...) + + sense = NILSSSensitivityFunction(sensealg, f, u0, p, tspan, g, dg, jevery, cur_time) + + # pre-allocate variables + gsave = Matrix{eltype(u0)}(undef, nstep, nseg) + y = Array{eltype(u0)}(undef, numindvar, nstep, nseg) + dudt = similar(y) + dgdu = similar(y) + vstar = Array{eltype(u0)}(undef, numparams, numindvar, nstep, nseg) # generalization for several parameters numindvar*numparams + vstar_perp = similar(vstar) + w = Array{eltype(u0)}(undef, numindvar, nstep, nseg, nus) + w_perp = similar(w) + + # assign initial values to y, v*, w + y[:, 1, 1] .= u0 + for i in 1:numparams + _vstar = @view vstar[i, :, 1, 1] + copyto!(_vstar, zero(u0)) + end + for ius in 1:nus + _w = @view w[:, 1, 1, ius] + rand!(rng, _w) + normalize!(_w) end - end - return nothing + + vstar0 = zeros(eltype(u0), numindvar * numparams) + w0 = vec(w[:, 1, 1, :]) + + R = Array{eltype(u0)}(undef, numparams, nseg - 1, nus, nus) + b = Array{eltype(u0)}(undef, numparams, (nseg - 1) * nus) + + # a weight matrix for integration, 0.5 at interfaces + weight = ones(1, nstep) + weight[1] /= 2 + weight[end] /= 2 + + # Construct Schur complement of the Lagrange multiplier method of the NILSS problem. + # See the paper on FD-NILSS + # find C^-1 + Cinv = Matrix{eltype(u0)}(undef, nseg * nus, nseg * nus) + Cinv .*= false + d = Vector{eltype(u0)}(undef, nseg * nus) + B = Matrix{eltype(u0)}(undef, (nseg - 1) * nus, nseg * nus) + B .*= false + + a = Vector{eltype(u0)}(undef, nseg * nus) + v = Array{eltype(u0)}(undef, numindvar, nstep, nseg) + v_perp = similar(v) + + # only need to use last step in each segment + ξ = Matrix{eltype(u0)}(undef, nseg, 2) + + res = similar(u0, numparams) + + NILSSProblem{typeof(sensealg), typeof(sense), typeof(forward_prob), typeof(prob), + typeof(u0), typeof(vstar0), typeof(w0), + typeof(T_seg), typeof(dtsave), typeof(gsave), typeof(y), typeof(vstar), + typeof(w), typeof(R), + typeof(b), typeof(weight), typeof(Cinv), typeof(d), typeof(B), typeof(a), + typeof(v), typeof(ξ), + typeof(g), typeof(dg), typeof(res)}(sensealg, sense, forward_prob, prob, + u0, vstar0, w0, + nus, T_seg, dtsave, gsave, y, dudt, + dgdu, vstar, vstar_perp, w, w_perp, R, + b, weight, Cinv, d, + B, a, v, v_perp, ξ, g, dg, res) end -function forward_sense(prob::NILSSProblem,nilss::NILSS,alg) - #TODO determine a good dtsave (ΔT in paper, see Sec.4.2) - @unpack nus, T_seg, dtsave, vstar, vstar_perp, w, w_perp, R, b, y, dudt, gsave, dgdu, forward_prob, u0, vstar0, w0 = prob - @unpack p, f = forward_prob - @unpack S, sensealg = f - @unpack nseg, nstep = nilss - @unpack numindvar, numparams = S - - # push forward - t1 = forward_prob.tspan[1] - t2 = forward_prob.tspan[1]+T_seg - _prob = ODEForwardSensitivityProblem(S.f,u0,(t1,t2),p,sensealg;nus=nus,w0=w0,v0=vstar0) - - for iseg=1:nseg - # compute y, w, vstar - # _sol is a numindvar*(1+nus+1) x nstep matrix - - dt = (t2 - t1) / (nstep-1) - _sol = Array(solve(_prob, alg, saveat=t1:dt:t2)) - - store_y_w_vstar!(y, w, vstar, _sol, nus, numindvar, numparams, iseg) - - # store dudt, objective g (gsave), and its derivative wrt. to u (dgdu) - dudt_g_dgdu!(dudt, gsave, dgdu, prob, y, forward_prob.p, iseg) - - # calculate w_perp, vstar_perp - perp!(w_perp, vstar_perp, w, vstar, dudt, iseg, numparams, nstep, nus) - - # update sense problem - if iseg < nseg - # renormalize at interfaces - renormalize!(R,b,w_perp,vstar_perp,y,vstar,w,iseg,numparams,nus) - t1 = forward_prob.tspan[1]+iseg*T_seg - t2 = forward_prob.tspan[1]+(iseg+1)*T_seg - _prob = ODEForwardSensitivityProblem(S.f,y[:,1,iseg+1],(t1,t2),p,sensealg; nus=nus, - w0=vec(w[:,1,iseg+1,:]),v0=vec(vstar[:,:,1,iseg+1])) +function (NS::NILSSForwardSensitivityFunction)(du, u, p, t) + @unpack S, nus = NS + y = @view u[1:(S.numindvar)] # These are the independent variables + dy = @view du[1:(S.numindvar)] + S.f(dy, y, p, t) # Make the first part be the ODE + + # Now do sensitivities + # Compute the Jacobian + + if !S.isautojacvec + if has_original_jac(S.f) + S.original_jac(S.J, y, p, t) # Calculate the Jacobian into J + else + S.uf.t = t + jacobian!(S.J, S.uf, y, S.f_cache, S.alg, S.jac_config) + end end - end -end + if DiffEqBase.has_paramjac(S.f) + S.paramjac(S.pJ, y, p, t) # Calculate the parameter Jacobian into pJ + else + S.pf.t = t + S.pf.u .= y + jacobian!(S.pJ, S.pf, p, S.f_cache, S.alg, S.paramjac_config) + end -function store_y_w_vstar!(y, w, vstar, sol, nus, numindvar, numparams, iseg) - # fill y - _y = @view y[:,:,iseg] - copyto!(_y, (@view sol[1:numindvar,:])) - - # fill w - # only calculate w one time, w can be reused for each parameter - for j=1:nus - indx1 = (j-1)*numindvar*1 + numindvar+1 - indx2 = (j-1)*numindvar*1 + 2*numindvar - - _w = @view w[:,:,iseg, j] - copyto!(_w, (@view sol[indx1:indx2,:])) - end - - # fill vstar - for i=1:numparams - indx1 = nus*numindvar*1 + i*numindvar+1 - indx2 = nus*numindvar*1 + (i+1)*numindvar - _vstar = @view vstar[i,:,:,iseg] - copyto!(_vstar, (@view sol[indx1:indx2,:])) - end - - return nothing + # Compute the parameter derivatives + for j in 1:(nus + 1) + for i in eachindex(p) + indx1 = (j - 1) * S.numindvar * 1 + i * S.numindvar + 1 + indx2 = (j - 1) * S.numindvar * 1 + (i + 1) * S.numindvar + Sj = @view u[indx1:indx2] + dp = @view du[indx1:indx2] + if !S.isautojacvec + mul!(dp, S.J, Sj) + else + jacobianvec!(dp, S.uf, y, Sj, S.alg, S.jac_config) + end + if j == nus + 1 + # inhomogenous (otherwise homogenous tangent solution) + dp .+= @view S.pJ[:, i] + end + end + end + return nothing end -function dudt_g_dgdu!(dudt, gsave, dgdu, nilssprob::NILSSProblem, y, p, iseg) - @unpack sensealg, diffcache, dg, g, prob = nilssprob - @unpack prob = nilssprob - @unpack jevery, cur_time = diffcache # akin to ``discrete" - - _y = @view y[:,:,iseg] +function forward_sense(prob::NILSSProblem, nilss::NILSS, alg) + #TODO determine a good dtsave (ΔT in paper, see Sec.4.2) + @unpack nus, T_seg, dtsave, vstar, vstar_perp, w, w_perp, R, b, y, dudt, gsave, dgdu, forward_prob, u0, vstar0, w0 = prob + @unpack p, f = forward_prob + @unpack S, sensealg = f + @unpack nseg, nstep = nilss + @unpack numindvar, numparams = S + + # push forward + t1 = forward_prob.tspan[1] + t2 = forward_prob.tspan[1] + T_seg + _prob = ODEForwardSensitivityProblem(S.f, u0, (t1, t2), p, sensealg; nus = nus, w0 = w0, + v0 = vstar0) + + for iseg in 1:nseg + # compute y, w, vstar + # _sol is a numindvar*(1+nus+1) x nstep matrix + + dt = (t2 - t1) / (nstep - 1) + _sol = Array(solve(_prob, alg, saveat = t1:dt:t2)) + + store_y_w_vstar!(y, w, vstar, _sol, nus, numindvar, numparams, iseg) + + # store dudt, objective g (gsave), and its derivative wrt. to u (dgdu) + dudt_g_dgdu!(dudt, gsave, dgdu, prob, y, forward_prob.p, iseg) + + # calculate w_perp, vstar_perp + perp!(w_perp, vstar_perp, w, vstar, dudt, iseg, numparams, nstep, nus) + + # update sense problem + if iseg < nseg + # renormalize at interfaces + renormalize!(R, b, w_perp, vstar_perp, y, vstar, w, iseg, numparams, nus) + t1 = forward_prob.tspan[1] + iseg * T_seg + t2 = forward_prob.tspan[1] + (iseg + 1) * T_seg + _prob = ODEForwardSensitivityProblem(S.f, y[:, 1, iseg + 1], (t1, t2), p, + sensealg; nus = nus, + w0 = vec(w[:, 1, iseg + 1, :]), + v0 = vec(vstar[:, :, 1, iseg + 1])) + end + end +end - for (j,u) in enumerate(eachcol(_y)) - _dgdu = @view dgdu[:,j,iseg] - _dudt = @view dudt[:,j,iseg] +function store_y_w_vstar!(y, w, vstar, sol, nus, numindvar, numparams, iseg) + # fill y + _y = @view y[:, :, iseg] + copyto!(_y, (@view sol[1:numindvar, :])) + + # fill w + # only calculate w one time, w can be reused for each parameter + for j in 1:nus + indx1 = (j - 1) * numindvar * 1 + numindvar + 1 + indx2 = (j - 1) * numindvar * 1 + 2 * numindvar + + _w = @view w[:, :, iseg, j] + copyto!(_w, (@view sol[indx1:indx2, :])) + end - # compute dudt - if isinplace(prob) - prob.f(_dudt,u,p,nothing) - else - copyto!(_dudt,prob.f(u,p,nothing)) + # fill vstar + for i in 1:numparams + indx1 = nus * numindvar * 1 + i * numindvar + 1 + indx2 = nus * numindvar * 1 + (i + 1) * numindvar + _vstar = @view vstar[i, :, :, iseg] + copyto!(_vstar, (@view sol[indx1:indx2, :])) end - # compute objective - gsave[j,iseg] = g(u,p,nothing) - - # compute gradient of objective wrt. state - if jevery!==nothing - # only bump on every jevery entry - # corresponds to (iseg-1)* value of dg - if (j-1) % jevery == 0 - accumulate_cost!(_dgdu, dg, u, p, nothing, sensealg, diffcache, cur_time[]) - cur_time[] += one(jevery) - end - else - # continuous cost function - accumulate_cost!(_dgdu, dg, u, p, nothing, sensealg, diffcache, j) + return nothing +end + +function dudt_g_dgdu!(dudt, gsave, dgdu, nilssprob::NILSSProblem, y, p, iseg) + @unpack sensealg, diffcache, dg, g, prob = nilssprob + @unpack prob = nilssprob + @unpack jevery, cur_time = diffcache # akin to ``discrete" + + _y = @view y[:, :, iseg] + + for (j, u) in enumerate(eachcol(_y)) + _dgdu = @view dgdu[:, j, iseg] + _dudt = @view dudt[:, j, iseg] + + # compute dudt + if isinplace(prob) + prob.f(_dudt, u, p, nothing) + else + copyto!(_dudt, prob.f(u, p, nothing)) + end + + # compute objective + gsave[j, iseg] = g(u, p, nothing) + + # compute gradient of objective wrt. state + if jevery !== nothing + # only bump on every jevery entry + # corresponds to (iseg-1)* value of dg + if (j - 1) % jevery == 0 + accumulate_cost!(_dgdu, dg, u, p, nothing, sensealg, diffcache, cur_time[]) + cur_time[] += one(jevery) + end + else + # continuous cost function + accumulate_cost!(_dgdu, dg, u, p, nothing, sensealg, diffcache, j) + end end - end - jevery !== nothing && (cur_time[] -= one(jevery)) # interface between segments gets two bumps - return nothing + jevery !== nothing && (cur_time[] -= one(jevery)) # interface between segments gets two bumps + return nothing end function perp!(w_perp, vstar_perp, w, vstar, dudt, iseg, numparams, nsteps, nus) - for indx_steps=1:nsteps - _dudt = @view dudt[:,indx_steps,iseg] - for indx_nus=1:nus - _w_perp = @view w_perp[:,indx_steps,iseg,indx_nus] - _w = @view w[:,indx_steps,iseg,indx_nus] - perp!(_w_perp, _w, _dudt) + for indx_steps in 1:nsteps + _dudt = @view dudt[:, indx_steps, iseg] + for indx_nus in 1:nus + _w_perp = @view w_perp[:, indx_steps, iseg, indx_nus] + _w = @view w[:, indx_steps, iseg, indx_nus] + perp!(_w_perp, _w, _dudt) + end + for indx_params in 1:numparams + _vstar_perp = @view vstar_perp[indx_params, :, indx_steps, iseg] + _vstar = @view vstar[indx_params, :, indx_steps, iseg] + perp!(_vstar_perp, _vstar, _dudt) + end end - for indx_params=1:numparams - _vstar_perp = @view vstar_perp[indx_params,:,indx_steps,iseg] - _vstar = @view vstar[indx_params,:,indx_steps,iseg] - perp!(_vstar_perp, _vstar, _dudt) - end - end - return nothing + return nothing end function perp!(v1, v2, v3) - v1 .= v2 - dot(v2, v3)/dot(v3, v3) * v3 + v1 .= v2 - dot(v2, v3) / dot(v3, v3) * v3 end -function renormalize!(R,b,w_perp,vstar_perp,y,vstar,w,iseg,numparams,nus) - for i=1:numparams - _b = @view b[i,(iseg-1)*nus+1:iseg*nus] - _R = @view R[i,iseg,:,:] - _w_perp = @view w_perp[:,end,iseg,:] - _vstar_perp = @view vstar_perp[i,:,end,iseg] - _w = @view w[:,1,iseg+1,:] - _vstar = @view vstar[i,:,1,iseg+1] - - Q_temp, R_temp = qr(_w_perp) - b_tmp = @view (Q_temp'*_vstar_perp)[1:nus] - - copyto!(_b, b_tmp) - copyto!(_R, R_temp) - # set new initial values - copyto!(_w, (@view Q_temp[:,1:nus])) - copyto!(_vstar, _vstar_perp - Q_temp*b_tmp) - end - _yend = @view y[:,end,iseg] - _ystart = @view y[:,1,iseg+1] - copyto!(_ystart, _yend) - - return nothing -end +function renormalize!(R, b, w_perp, vstar_perp, y, vstar, w, iseg, numparams, nus) + for i in 1:numparams + _b = @view b[i, ((iseg - 1) * nus + 1):(iseg * nus)] + _R = @view R[i, iseg, :, :] + _w_perp = @view w_perp[:, end, iseg, :] + _vstar_perp = @view vstar_perp[i, :, end, iseg] + _w = @view w[:, 1, iseg + 1, :] + _vstar = @view vstar[i, :, 1, iseg + 1] + + Q_temp, R_temp = qr(_w_perp) + b_tmp = @view (Q_temp' * _vstar_perp)[1:nus] + + copyto!(_b, b_tmp) + copyto!(_R, R_temp) + # set new initial values + copyto!(_w, (@view Q_temp[:, 1:nus])) + copyto!(_vstar, _vstar_perp - Q_temp * b_tmp) + end + _yend = @view y[:, end, iseg] + _ystart = @view y[:, 1, iseg + 1] + copyto!(_ystart, _yend) + return nothing +end -function compute_Cinv!(Cinv,w_perp,weight,nseg,nus,indxp) - # construct Schur complement of Lagrange multiplier - _weight = @view weight[1,:] - for iseg=1:nseg - _C = @view Cinv[(iseg-1)*nus+1:iseg*nus, (iseg-1)*nus+1:iseg*nus] - for i=1:nus - wi = @view w_perp[:,:,iseg,i] - for j =1:nus - wj = @view w_perp[:,:,iseg,j] - _C[i,j] = sum(wi .* wj * _weight) - end +function compute_Cinv!(Cinv, w_perp, weight, nseg, nus, indxp) + # construct Schur complement of Lagrange multiplier + _weight = @view weight[1, :] + for iseg in 1:nseg + _C = @view Cinv[((iseg - 1) * nus + 1):(iseg * nus), + ((iseg - 1) * nus + 1):(iseg * nus)] + for i in 1:nus + wi = @view w_perp[:, :, iseg, i] + for j in 1:nus + wj = @view w_perp[:, :, iseg, j] + _C[i, j] = sum(wi .* wj * _weight) + end + end + invC = inv(_C) + copyto!(_C, invC) end - invC = inv(_C) - copyto!(_C, invC) - end - return nothing + return nothing end -function compute_d!(d,w_perp,vstar_perp,weight,nseg,nus,indxp) - # construct d - _weight = @view weight[1,:] - for iseg=1:nseg - _d = @view d[(iseg-1)*nus+1:iseg*nus] - vi = @view vstar_perp[indxp,:,:,iseg] - for i=1:nus - wi = @view w_perp[:,:,iseg,i] - _d[i] = sum(wi .* vi * _weight) +function compute_d!(d, w_perp, vstar_perp, weight, nseg, nus, indxp) + # construct d + _weight = @view weight[1, :] + for iseg in 1:nseg + _d = @view d[((iseg - 1) * nus + 1):(iseg * nus)] + vi = @view vstar_perp[indxp, :, :, iseg] + for i in 1:nus + wi = @view w_perp[:, :, iseg, i] + _d[i] = sum(wi .* vi * _weight) + end end - end - return nothing + return nothing end -function compute_B!(B,R,nseg,nus,indxp) - for iseg=1:nseg-1 - _B = @view B[(iseg-1)*nus+1:iseg*nus, (iseg-1)*nus+1:iseg*nus] - _R = @view R[indxp,iseg,:,:] - copyto!(_B, -_R) - # off diagonal one - for i=1:nus - B[(iseg-1)*nus+i, iseg*nus+i] = one(eltype(R)) +function compute_B!(B, R, nseg, nus, indxp) + for iseg in 1:(nseg - 1) + _B = @view B[((iseg - 1) * nus + 1):(iseg * nus), + ((iseg - 1) * nus + 1):(iseg * nus)] + _R = @view R[indxp, iseg, :, :] + copyto!(_B, -_R) + # off diagonal one + for i in 1:nus + B[(iseg - 1) * nus + i, iseg * nus + i] = one(eltype(R)) + end end - end - return nothing + return nothing end -function compute_a!(a,B,Cinv,b,d,indxp) - _b = @view b[indxp,:] +function compute_a!(a, B, Cinv, b, d, indxp) + _b = @view b[indxp, :] - lbd = (-B*Cinv*B') \ (B*Cinv*d + _b) - a .= -Cinv*(B'*lbd + d) - return nothing + lbd = (-B * Cinv * B') \ (B * Cinv * d + _b) + a .= -Cinv * (B' * lbd + d) + return nothing end -function compute_v!(v,v_perp,vstar,vstar_perp,w,w_perp,a,nseg,nus,indxp) - _vstar = @view vstar[indxp,:,:,:] - _vstar_perp = @view vstar_perp[indxp,:,:,:] +function compute_v!(v, v_perp, vstar, vstar_perp, w, w_perp, a, nseg, nus, indxp) + _vstar = @view vstar[indxp, :, :, :] + _vstar_perp = @view vstar_perp[indxp, :, :, :] - copyto!(v, _vstar) - copyto!(v_perp, _vstar_perp) + copyto!(v, _vstar) + copyto!(v_perp, _vstar_perp) - for iseg=1:nseg - vi = @view v[:,:,iseg] - vpi = @view v_perp[:,:,iseg] - for i=1:nus - wi = @view w[:,:,iseg,i] - wpi = @view w_perp[:,:,iseg,i] + for iseg in 1:nseg + vi = @view v[:, :, iseg] + vpi = @view v_perp[:, :, iseg] + for i in 1:nus + wi = @view w[:, :, iseg, i] + wpi = @view w_perp[:, :, iseg, i] - vi .+= a[(iseg-1)*nus+i]*wi - vpi .+= a[(iseg-1)*nus+i]*wpi + vi .+= a[(iseg - 1) * nus + i] * wi + vpi .+= a[(iseg - 1) * nus + i] * wpi + end end - end - - return nothing -end -function compute_xi(ξ,v,dudt,nseg) - for iseg=1:nseg - _v = @view v[:,1,iseg] - _dudt = @view dudt[:,1,iseg] - ξ[iseg,1] = dot(_v,_dudt)/dot(_dudt,_dudt) - - _v = @view v[:,end,iseg] - _dudt = @view dudt[:,end,iseg] - ξ[iseg,2] = dot(_v,_dudt)/dot(_dudt,_dudt) - end - # check if segmentation is chosen correctly - _ξ = ξ[:,1] - all(_ξ.<1e-4) || @warn "Detected a large value of ξ at the beginning of a segment." - return nothing + return nothing end -function accumulate_cost!(_dgdu, dg, u, p, t, sensealg::NILSS, diffcache::NILSSSensitivityFunction, j) - @unpack dg_val, pgpu, pgpu_config, pgpp, pgpp_config = diffcache +function compute_xi(ξ, v, dudt, nseg) + for iseg in 1:nseg + _v = @view v[:, 1, iseg] + _dudt = @view dudt[:, 1, iseg] + ξ[iseg, 1] = dot(_v, _dudt) / dot(_dudt, _dudt) - if dg===nothing - if dg_val isa Tuple - SciMLSensitivity.gradient!(dg_val[1], pgpu, u, sensealg, pgpu_config) - copyto!(_dgdu, dg_val[1]) - else - SciMLSensitivity.gradient!(dg_val, pgpu, u, sensealg, pgpu_config) - copyto!(_dgdu, dg_val) + _v = @view v[:, end, iseg] + _dudt = @view dudt[:, end, iseg] + ξ[iseg, 2] = dot(_v, _dudt) / dot(_dudt, _dudt) end - else - if dg_val isa Tuple - dg[1](dg_val[1],u,p,nothing,j) - @. _dgdu = dg_val[1] + # check if segmentation is chosen correctly + _ξ = ξ[:, 1] + all(_ξ .< 1e-4) || @warn "Detected a large value of ξ at the beginning of a segment." + return nothing +end + +function accumulate_cost!(_dgdu, dg, u, p, t, sensealg::NILSS, + diffcache::NILSSSensitivityFunction, j) + @unpack dg_val, pgpu, pgpu_config, pgpp, pgpp_config = diffcache + + if dg === nothing + if dg_val isa Tuple + SciMLSensitivity.gradient!(dg_val[1], pgpu, u, sensealg, pgpu_config) + copyto!(_dgdu, dg_val[1]) + else + SciMLSensitivity.gradient!(dg_val, pgpu, u, sensealg, pgpu_config) + copyto!(_dgdu, dg_val) + end else - dg(dg_val,u,p,nothing,j) - @. _dgdu = dg_val + if dg_val isa Tuple + dg[1](dg_val[1], u, p, nothing, j) + @. _dgdu = dg_val[1] + else + dg(dg_val, u, p, nothing, j) + @. _dgdu = dg_val + end end - end - return nothing + return nothing end -function shadow_forward(prob::NILSSProblem,alg; sensealg=prob.sensealg) - shadow_forward(prob,sensealg,alg) +function shadow_forward(prob::NILSSProblem, alg; sensealg = prob.sensealg) + shadow_forward(prob, sensealg, alg) end -function shadow_forward(prob::NILSSProblem,sensealg::NILSS,alg) - @unpack nseg, nstep = sensealg - @unpack res, nus, dtsave, vstar, vstar_perp, w, w_perp, R, b, dudt, - gsave, dgdu, forward_prob, weight, Cinv, d, B, a, v, v_perp, ξ = prob - @unpack numindvar, numparams = forward_prob.f.S - - # reset dg pointer - @unpack jevery, cur_time = prob.diffcache - jevery !== nothing && (cur_time[] = one(jevery)) - - # compute vstar, w - forward_sense(prob,sensealg,alg) - - # compute avg objective - gavg = sum(prob.weight*gsave)/((nstep-1)*nseg) - - # reset gradient - res .*= false - - # loop over parameters - for i=1:numparams - compute_Cinv!(Cinv,w_perp,weight,nseg,nus,i) - compute_d!(d,w_perp,vstar_perp,weight,nseg,nus,i) - compute_B!(B,R,nseg,nus,i) - compute_a!(a,B,Cinv,b,d,i) - compute_v!(v,v_perp,vstar,vstar_perp,w,w_perp,a,nseg,nus,i) - compute_xi(ξ,v,dudt,nseg) - - - _weight = @view weight[1,:] - - for iseg=1:nseg - _dgdu = @view dgdu[:,:,iseg] - _v = @view v[:,:,iseg] - res[i] += sum((_v.*_dgdu)*_weight)/((nstep-1)*nseg) - res[i] += ξ[iseg,end]*(gavg-gsave[end,iseg])/(dtsave*(nstep-1)*nseg) +function shadow_forward(prob::NILSSProblem, sensealg::NILSS, alg) + @unpack nseg, nstep = sensealg + @unpack res, nus, dtsave, vstar, vstar_perp, w, w_perp, R, b, dudt, + gsave, dgdu, forward_prob, weight, Cinv, d, B, a, v, v_perp, ξ = prob + @unpack numindvar, numparams = forward_prob.f.S + + # reset dg pointer + @unpack jevery, cur_time = prob.diffcache + jevery !== nothing && (cur_time[] = one(jevery)) + + # compute vstar, w + forward_sense(prob, sensealg, alg) + + # compute avg objective + gavg = sum(prob.weight * gsave) / ((nstep - 1) * nseg) + + # reset gradient + res .*= false + + # loop over parameters + for i in 1:numparams + compute_Cinv!(Cinv, w_perp, weight, nseg, nus, i) + compute_d!(d, w_perp, vstar_perp, weight, nseg, nus, i) + compute_B!(B, R, nseg, nus, i) + compute_a!(a, B, Cinv, b, d, i) + compute_v!(v, v_perp, vstar, vstar_perp, w, w_perp, a, nseg, nus, i) + compute_xi(ξ, v, dudt, nseg) + + _weight = @view weight[1, :] + + for iseg in 1:nseg + _dgdu = @view dgdu[:, :, iseg] + _v = @view v[:, :, iseg] + res[i] += sum((_v .* _dgdu) * _weight) / ((nstep - 1) * nseg) + res[i] += ξ[iseg, end] * (gavg - gsave[end, iseg]) / + (dtsave * (nstep - 1) * nseg) + end end - end - return res + return res end -check_for_g(sensealg::NILSS,g) = (g===nothing && error("To use NILSS, g must be passed as a kwarg to `NILSS(g=g)`.")) +function check_for_g(sensealg::NILSS, g) + (g === nothing && error("To use NILSS, g must be passed as a kwarg to `NILSS(g=g)`.")) +end diff --git a/src/quadrature_adjoint.jl b/src/quadrature_adjoint.jl index abe43b103..ec2f6145f 100644 --- a/src/quadrature_adjoint.jl +++ b/src/quadrature_adjoint.jl @@ -1,352 +1,371 @@ -struct ODEQuadratureAdjointSensitivityFunction{C<:AdjointDiffCache,Alg<:QuadratureAdjoint, - uType,SType,fType<:DiffEqBase.AbstractDiffEqFunction} <: SensitivityFunction - diffcache::C - sensealg::Alg - discrete::Bool - y::uType - sol::SType - f::fType +struct ODEQuadratureAdjointSensitivityFunction{C <: AdjointDiffCache, + Alg <: QuadratureAdjoint, + uType, SType, + fType <: DiffEqBase.AbstractDiffEqFunction + } <: SensitivityFunction + diffcache::C + sensealg::Alg + discrete::Bool + y::uType + sol::SType + f::fType end -function ODEQuadratureAdjointSensitivityFunction(g,sensealg,discrete,sol,dg) - diffcache, y = adjointdiffcache(g,sensealg,discrete,sol,dg,sol.prob.f;quad=true) - return ODEQuadratureAdjointSensitivityFunction(diffcache,sensealg,discrete, - y,sol,sol.prob.f) +function ODEQuadratureAdjointSensitivityFunction(g, sensealg, discrete, sol, dg) + diffcache, y = adjointdiffcache(g, sensealg, discrete, sol, dg, sol.prob.f; quad = true) + return ODEQuadratureAdjointSensitivityFunction(diffcache, sensealg, discrete, + y, sol, sol.prob.f) end # u = λ' -function (S::ODEQuadratureAdjointSensitivityFunction)(du,u,p,t) - @unpack sol, discrete = S - f = sol.prob.f +function (S::ODEQuadratureAdjointSensitivityFunction)(du, u, p, t) + @unpack sol, discrete = S + f = sol.prob.f - λ,grad,y,dλ,dgrad,dy = split_states(du,u,t,S) + λ, grad, y, dλ, dgrad, dy = split_states(du, u, t, S) - vecjacobian!(dλ, y, λ, p, t, S) - dλ .*= -one(eltype(λ)) + vecjacobian!(dλ, y, λ, p, t, S) + dλ .*= -one(eltype(λ)) - discrete || accumulate_cost!(dλ, y, p, t, S) - return nothing + discrete || accumulate_cost!(dλ, y, p, t, S) + return nothing end -function split_states(du,u,t,S::ODEQuadratureAdjointSensitivityFunction;update=true) - @unpack y, sol = S +function split_states(du, u, t, S::ODEQuadratureAdjointSensitivityFunction; update = true) + @unpack y, sol = S - if update - if typeof(t) <: ForwardDiff.Dual && eltype(y) <: AbstractFloat - y = sol(t, continuity=:right) - else - sol(y,t, continuity=:right) + if update + if typeof(t) <: ForwardDiff.Dual && eltype(y) <: AbstractFloat + y = sol(t, continuity = :right) + else + sol(y, t, continuity = :right) + end end - end - λ = u - dλ = du + λ = u + dλ = du - λ,nothing,y,dλ,nothing,nothing + λ, nothing, y, dλ, nothing, nothing end # g is either g(t,u,p) or discrete g(t,u,i) -@noinline function ODEAdjointProblem(sol,sensealg::QuadratureAdjoint, - t=nothing, - dg_discrete::DG1=nothing,dg_continuous::DG2=nothing, - g::G=nothing; - callback=CallbackSet()) where {DG1,DG2,G} - - dg_discrete===nothing && dg_continuous===nothing && g===nothing && error("Either `dg_discrete`, `dg_continuous`, or `g` must be specified.") - - @unpack f, p, u0, tspan = sol.prob - terminated = false - if hasfield(typeof(sol),:retcode) - if sol.retcode == :Terminated - tspan = (tspan[1], sol.t[end]) - terminated = true +@noinline function ODEAdjointProblem(sol, sensealg::QuadratureAdjoint, + t = nothing, + dg_discrete::DG1 = nothing, + dg_continuous::DG2 = nothing, + g::G = nothing; + callback = CallbackSet()) where {DG1, DG2, G} + dg_discrete === nothing && dg_continuous === nothing && g === nothing && + error("Either `dg_discrete`, `dg_continuous`, or `g` must be specified.") + + @unpack f, p, u0, tspan = sol.prob + terminated = false + if hasfield(typeof(sol), :retcode) + if sol.retcode == :Terminated + tspan = (tspan[1], sol.t[end]) + terminated = true + end end - end - tspan = reverse(tspan) - - discrete = (t !== nothing && dg_continuous === nothing) - - len = length(u0) - λ = similar(u0, len) - λ .= false - sense = ODEQuadratureAdjointSensitivityFunction(g,sensealg,discrete,sol,dg_continuous) - - init_cb = (discrete || dg_discrete!==nothing) # && tspan[1] == t[end] - z0 = vec(zero(λ)) - cb, duplicate_iterator_times = generate_callbacks(sense, dg_discrete, λ, t, tspan[2], callback, init_cb, terminated) - - jac_prototype = sol.prob.f.jac_prototype - adjoint_jac_prototype = !sense.discrete || jac_prototype === nothing ? nothing : copy(jac_prototype') - - original_mm = sol.prob.f.mass_matrix - if original_mm === I || original_mm === (I,I) - odefun = ODEFunction(sense, jac_prototype=adjoint_jac_prototype) - else - odefun = ODEFunction(sense, mass_matrix=sol.prob.f.mass_matrix', jac_prototype=adjoint_jac_prototype) - end - return ODEProblem(odefun,z0,tspan,p,callback=cb) -end + tspan = reverse(tspan) -struct AdjointSensitivityIntegrand{pType,uType,lType,rateType,S,AS,PF,PJC,PJT,DGP,G} - sol::S - adj_sol::AS - p::pType - y::uType - λ::lType - pf::PF - f_cache::rateType - pJ::PJT - paramjac_config::PJC - sensealg::QuadratureAdjoint - dgdp_cache::DGP - dgdp::G -end + discrete = (t !== nothing && dg_continuous === nothing) -function AdjointSensitivityIntegrand(sol,adj_sol,sensealg,dgdp=nothing) - prob = sol.prob - @unpack f, p, tspan, u0 = prob - numparams = length(p) - y = zero(sol.prob.u0) - λ = zero(adj_sol.prob.u0) - # we need to alias `y` - f_cache = zero(y) - f_cache .= false - isautojacvec = get_jacvec(sensealg) - - dgdp_cache = dgdp === nothing ? nothing : zero(p) - - if sensealg.autojacvec isa ReverseDiffVJP - tape = if DiffEqBase.isinplace(prob) - ReverseDiff.GradientTape((y, prob.p, [tspan[2]])) do u,p,t - du1 = similar(p, size(u)) - du1 .= false - f(du1,u,p,first(t)) - return vec(du1) - end + len = length(u0) + λ = similar(u0, len) + λ .= false + sense = ODEQuadratureAdjointSensitivityFunction(g, sensealg, discrete, sol, + dg_continuous) + + init_cb = (discrete || dg_discrete !== nothing) # && tspan[1] == t[end] + z0 = vec(zero(λ)) + cb, duplicate_iterator_times = generate_callbacks(sense, dg_discrete, λ, t, tspan[2], + callback, init_cb, terminated) + + jac_prototype = sol.prob.f.jac_prototype + adjoint_jac_prototype = !sense.discrete || jac_prototype === nothing ? nothing : + copy(jac_prototype') + + original_mm = sol.prob.f.mass_matrix + if original_mm === I || original_mm === (I, I) + odefun = ODEFunction(sense, jac_prototype = adjoint_jac_prototype) else - ReverseDiff.GradientTape((y, prob.p, [tspan[2]])) do u,p,t - vec(f(u,p,first(t))) - end + odefun = ODEFunction(sense, mass_matrix = sol.prob.f.mass_matrix', + jac_prototype = adjoint_jac_prototype) end - if compile_tape(sensealg.autojacvec) - paramjac_config = ReverseDiff.compile(tape) + return ODEProblem(odefun, z0, tspan, p, callback = cb) +end + +struct AdjointSensitivityIntegrand{pType, uType, lType, rateType, S, AS, PF, PJC, PJT, DGP, + G} + sol::S + adj_sol::AS + p::pType + y::uType + λ::lType + pf::PF + f_cache::rateType + pJ::PJT + paramjac_config::PJC + sensealg::QuadratureAdjoint + dgdp_cache::DGP + dgdp::G +end + +function AdjointSensitivityIntegrand(sol, adj_sol, sensealg, dgdp = nothing) + prob = sol.prob + @unpack f, p, tspan, u0 = prob + numparams = length(p) + y = zero(sol.prob.u0) + λ = zero(adj_sol.prob.u0) + # we need to alias `y` + f_cache = zero(y) + f_cache .= false + isautojacvec = get_jacvec(sensealg) + + dgdp_cache = dgdp === nothing ? nothing : zero(p) + + if sensealg.autojacvec isa ReverseDiffVJP + tape = if DiffEqBase.isinplace(prob) + ReverseDiff.GradientTape((y, prob.p, [tspan[2]])) do u, p, t + du1 = similar(p, size(u)) + du1 .= false + f(du1, u, p, first(t)) + return vec(du1) + end + else + ReverseDiff.GradientTape((y, prob.p, [tspan[2]])) do u, p, t + vec(f(u, p, first(t))) + end + end + if compile_tape(sensealg.autojacvec) + paramjac_config = ReverseDiff.compile(tape) + else + paramjac_config = tape + end + pf = nothing + pJ = nothing + elseif sensealg.autojacvec isa EnzymeVJP + paramjac_config = zero(y), zero(y) + pf = let f = f.f + if DiffEqBase.isinplace(prob) && prob isa RODEProblem + function (out, u, _p, t, W) + f(out, u, _p, t, W) + nothing + end + elseif DiffEqBase.isinplace(prob) + function (out, u, _p, t) + f(out, u, _p, t) + nothing + end + elseif !DiffEqBase.isinplace(prob) && prob isa RODEProblem + function (out, u, _p, t, W) + out .= f(u, _p, t, W) + nothing + end + else + !DiffEqBase.isinplace(prob) + function (out, u, _p, t) + out .= f(u, _p, t) + nothing + end + end + end + pJ = nothing + elseif isautojacvec # Zygote + paramjac_config = nothing + pf = nothing + pJ = nothing else - paramjac_config = tape + pf = DiffEqBase.ParamJacobianWrapper(f, tspan[1], y) + pJ = similar(u0, length(u0), numparams) + paramjac_config = build_param_jac_config(sensealg, pf, y, p) end - pf = nothing - pJ = nothing - elseif sensealg.autojacvec isa EnzymeVJP - paramjac_config = zero(y),zero(y) - pf = let f = f.f - if DiffEqBase.isinplace(prob) && prob isa RODEProblem - function (out,u,_p,t,W) - f(out, u, _p, t, W) - nothing - end - elseif DiffEqBase.isinplace(prob) - function (out,u,_p,t) - f(out, u, _p, t) - nothing - end - elseif !DiffEqBase.isinplace(prob) && prob isa RODEProblem - function (out,u,_p,t,W) - out .= f(u, _p, t, W) - nothing - end - else !DiffEqBase.isinplace(prob) - function (out,u,_p,t) - out .= f(u, _p, t) - nothing - end - end - end - pJ = nothing - elseif isautojacvec # Zygote - paramjac_config = nothing - pf = nothing - pJ = nothing - else - pf = DiffEqBase.ParamJacobianWrapper(f,tspan[1],y) - pJ = similar(u0,length(u0),numparams) - paramjac_config = build_param_jac_config(sensealg,pf,y,p) - end - AdjointSensitivityIntegrand(sol,adj_sol,p,y,λ,pf,f_cache,pJ,paramjac_config,sensealg,dgdp_cache,dgdp) + AdjointSensitivityIntegrand(sol, adj_sol, p, y, λ, pf, f_cache, pJ, paramjac_config, + sensealg, dgdp_cache, dgdp) end -function (S::AdjointSensitivityIntegrand)(out,t) - @unpack y, λ, pJ, pf, p, f_cache, dgdp_cache, paramjac_config, sensealg, sol, adj_sol = S - f = sol.prob.f - sol(y,t) - adj_sol(λ,t) - isautojacvec = get_jacvec(sensealg) - # y is aliased - - if !isautojacvec - if DiffEqBase.has_paramjac(f) - f.paramjac(pJ,y,p,t) # Calculate the parameter Jacobian into pJ - else - pf.t = t - jacobian!(pJ, pf, p, f_cache, sensealg, paramjac_config) +function (S::AdjointSensitivityIntegrand)(out, t) + @unpack y, λ, pJ, pf, p, f_cache, dgdp_cache, paramjac_config, sensealg, sol, adj_sol = S + f = sol.prob.f + sol(y, t) + adj_sol(λ, t) + isautojacvec = get_jacvec(sensealg) + # y is aliased + + if !isautojacvec + if DiffEqBase.has_paramjac(f) + f.paramjac(pJ, y, p, t) # Calculate the parameter Jacobian into pJ + else + pf.t = t + jacobian!(pJ, pf, p, f_cache, sensealg, paramjac_config) + end + mul!(out', λ', pJ) + elseif sensealg.autojacvec isa ReverseDiffVJP + tape = paramjac_config + tu, tp, tt = ReverseDiff.input_hook(tape) + output = ReverseDiff.output_hook(tape) + ReverseDiff.unseed!(tu) # clear any "leftover" derivatives from previous calls + ReverseDiff.unseed!(tp) + ReverseDiff.unseed!(tt) + ReverseDiff.value!(tu, y) + ReverseDiff.value!(tp, p) + ReverseDiff.value!(tt, [t]) + ReverseDiff.forward_pass!(tape) + ReverseDiff.increment_deriv!(output, λ) + ReverseDiff.reverse_pass!(tape) + copyto!(vec(out), ReverseDiff.deriv(tp)) + elseif sensealg.autojacvec isa ZygoteVJP + _dy, back = Zygote.pullback(p) do p + vec(f(y, p, t)) + end + tmp = back(λ) + out[:] .= vec(tmp[1]) + elseif sensealg.autojacvec isa EnzymeVJP + tmp3, tmp4 = paramjac_config + tmp4 .= λ + out .= 0 + Enzyme.autodiff(pf, Enzyme.Duplicated(tmp3, tmp4), + y, Enzyme.Duplicated(p, out), t) end - mul!(out',λ',pJ) - elseif sensealg.autojacvec isa ReverseDiffVJP - tape = paramjac_config - tu, tp, tt = ReverseDiff.input_hook(tape) - output = ReverseDiff.output_hook(tape) - ReverseDiff.unseed!(tu) # clear any "leftover" derivatives from previous calls - ReverseDiff.unseed!(tp) - ReverseDiff.unseed!(tt) - ReverseDiff.value!(tu, y) - ReverseDiff.value!(tp, p) - ReverseDiff.value!(tt, [t]) - ReverseDiff.forward_pass!(tape) - ReverseDiff.increment_deriv!(output, λ) - ReverseDiff.reverse_pass!(tape) - copyto!(vec(out), ReverseDiff.deriv(tp)) - elseif sensealg.autojacvec isa ZygoteVJP - _dy, back = Zygote.pullback(p) do p - vec(f(y, p, t)) + + # TODO: Add tracker? + + if S.dgdp !== nothing + S.dgdp(dgdp_cache, y, p, t) + out .+= dgdp_cache end - tmp = back(λ) - out[:] .= vec(tmp[1]) - elseif sensealg.autojacvec isa EnzymeVJP - tmp3,tmp4 = paramjac_config - tmp4 .= λ - out .= 0 - Enzyme.autodiff(pf,Enzyme.Duplicated(tmp3,tmp4), - y,Enzyme.Duplicated(p, out),t) - end - - # TODO: Add tracker? - - if S.dgdp !== nothing - S.dgdp(dgdp_cache, y, p, t) - out .+= dgdp_cache - end - out' + out' end function (S::AdjointSensitivityIntegrand)(t) - out = similar(S.p) - S(out,t) + out = similar(S.p) + S(out, t) end -function _adjoint_sensitivities(sol,sensealg::QuadratureAdjoint,alg;t=nothing, - dg_discrete=nothing,dg_continuous=nothing, - g=nothing, - abstol=sensealg.abstol,reltol=sensealg.reltol, +function _adjoint_sensitivities(sol, sensealg::QuadratureAdjoint, alg; t = nothing, + dg_discrete = nothing, dg_continuous = nothing, + g = nothing, + abstol = sensealg.abstol, reltol = sensealg.reltol, callback = CallbackSet(), kwargs...) - dgdu, dgdp = dg_continuous isa Tuple ? dg_continuous : (dg_continuous, nothing) - adj_prob = ODEAdjointProblem(sol,sensealg,t,dg_discrete,dgdu,g; callback) - adj_sol = solve(adj_prob,alg;abstol=abstol,reltol=reltol, - save_everystep=true,save_start=true,kwargs...) - - p = sol.prob.p - if p === nothing || p === DiffEqBase.NullParameters() - return -adj_sol[end],nothing - else - integrand = AdjointSensitivityIntegrand(sol,adj_sol,sensealg,dgdp) - - if t === nothing - res,err = quadgk(integrand,sol.prob.tspan[1],sol.prob.tspan[2], - atol=abstol,rtol=reltol) + dgdu, dgdp = dg_continuous isa Tuple ? dg_continuous : (dg_continuous, nothing) + adj_prob = ODEAdjointProblem(sol, sensealg, t, dg_discrete, dgdu, g; callback) + adj_sol = solve(adj_prob, alg; abstol = abstol, reltol = reltol, + save_everystep = true, save_start = true, kwargs...) + + p = sol.prob.p + if p === nothing || p === DiffEqBase.NullParameters() + return -adj_sol[end], nothing else - res = zero(integrand.p)' - - if callback!==nothing - cur_time = length(t) - dλ = similar(integrand.λ) - dλ .*= false - dgrad = similar(res) - dgrad .*= false - end - - # correction for end interval. - if t[end] != sol.prob.tspan[2] - res .+= quadgk(integrand, t[end], sol.prob.tspan[end], - atol=abstol, rtol=reltol)[1] - end - - for i in length(t)-1:-1:1 - res .+= quadgk(integrand,t[i],t[i+1], - atol=abstol,rtol=reltol)[1] - if t[i]==t[i+1] - for cb in callback.discrete_callbacks - if t[i] ∈ cb.affect!.event_times - integrand = update_integrand_and_dgrad(res,sensealg,cb,integrand,adj_prob,sol,dg_discrete,dλ,dgrad,t[i],cur_time) + integrand = AdjointSensitivityIntegrand(sol, adj_sol, sensealg, dgdp) + + if t === nothing + res, err = quadgk(integrand, sol.prob.tspan[1], sol.prob.tspan[2], + atol = abstol, rtol = reltol) + else + res = zero(integrand.p)' + + if callback !== nothing + cur_time = length(t) + dλ = similar(integrand.λ) + dλ .*= false + dgrad = similar(res) + dgrad .*= false + end + + # correction for end interval. + if t[end] != sol.prob.tspan[2] + res .+= quadgk(integrand, t[end], sol.prob.tspan[end], + atol = abstol, rtol = reltol)[1] + end + + for i in (length(t) - 1):-1:1 + res .+= quadgk(integrand, t[i], t[i + 1], + atol = abstol, rtol = reltol)[1] + if t[i] == t[i + 1] + for cb in callback.discrete_callbacks + if t[i] ∈ cb.affect!.event_times + integrand = update_integrand_and_dgrad(res, sensealg, cb, + integrand, adj_prob, sol, + dg_discrete, dλ, dgrad, + t[i], cur_time) + end + end + for cb in callback.continuous_callbacks + if t[i] ∈ cb.affect!.event_times || + t[i] ∈ cb.affect_neg!.event_times + integrand = update_integrand_and_dgrad(res, sensealg, cb, + integrand, adj_prob, sol, + dg_discrete, dλ, dgrad, + t[i], cur_time) + end + end + end + callback !== nothing && (cur_time -= one(cur_time)) end - end - for cb in callback.continuous_callbacks - if t[i] ∈ cb.affect!.event_times || t[i] ∈ cb.affect_neg!.event_times - integrand = update_integrand_and_dgrad(res,sensealg,cb,integrand,adj_prob,sol,dg_discrete,dλ,dgrad,t[i],cur_time) + # correction for start interval + if t[1] != sol.prob.tspan[1] + res .+= quadgk(integrand, sol.prob.tspan[1], t[1], + atol = abstol, rtol = reltol)[1] end - end end - callback!==nothing && (cur_time -= one(cur_time)) - end - # correction for start interval - if t[1] != sol.prob.tspan[1] - res .+= quadgk(integrand,sol.prob.tspan[1],t[1], - atol=abstol,rtol=reltol)[1] - end + return adj_sol[end], res end - return adj_sol[end], res - end end -function update_p_integrand(integrand::AdjointSensitivityIntegrand,p) - @unpack sol, adj_sol, y, λ, pf, f_cache, pJ, paramjac_config, sensealg, dgdp_cache, dgdp = integrand - AdjointSensitivityIntegrand(sol,adj_sol,p,y,λ,pf,f_cache,pJ,paramjac_config,sensealg,dgdp_cache,dgdp) +function update_p_integrand(integrand::AdjointSensitivityIntegrand, p) + @unpack sol, adj_sol, y, λ, pf, f_cache, pJ, paramjac_config, sensealg, dgdp_cache, dgdp = integrand + AdjointSensitivityIntegrand(sol, adj_sol, p, y, λ, pf, f_cache, pJ, paramjac_config, + sensealg, dgdp_cache, dgdp) end -function update_integrand_and_dgrad(res,sensealg::QuadratureAdjoint,cb,integrand,adj_prob,sol,dg,dλ,dgrad,t,cur_time) +function update_integrand_and_dgrad(res, sensealg::QuadratureAdjoint, cb, integrand, + adj_prob, sol, dg, dλ, dgrad, t, cur_time) + indx, pos_neg = get_indx(cb, t) + tprev = get_tprev(cb, indx, pos_neg) + + wp = let tprev = tprev, pos_neg = pos_neg + function (dp, p, u, t) + _affect! = get_affect!(cb, pos_neg) + fakeinteg = FakeIntegrator([x for x in u], [x for x in p], t, tprev) + _affect!(fakeinteg) + dp .= fakeinteg.p + end + end - indx, pos_neg = get_indx(cb, t) - tprev = get_tprev(cb,indx,pos_neg) + _p = similar(integrand.p, size(integrand.p)) + wp(_p, integrand.p, integrand.y, t) - wp = let tprev=tprev, pos_neg=pos_neg - function (dp,p,u,t) - _affect! = get_affect!(cb,pos_neg) - fakeinteg = FakeIntegrator([x for x in u],[x for x in p],t,tprev) - _affect!(fakeinteg) - dp .= fakeinteg.p + if _p != integrand.p + fakeSp = CallbackSensitivityFunction(wp, sensealg, adj_prob.f.f.diffcache, sol.prob) + #vjp with Jacobin given by dw/dp before event and vector given by grad + vecjacobian!(res, integrand.p, res, integrand.y, t, fakeSp; + dgrad = nothing, dy = nothing) + integrand = update_p_integrand(integrand, _p) end - end - - _p = similar(integrand.p, size(integrand.p)) - wp(_p,integrand.p,integrand.y,t) - - if _p != integrand.p - fakeSp = CallbackSensitivityFunction(wp,sensealg,adj_prob.f.f.diffcache,sol.prob) - #vjp with Jacobin given by dw/dp before event and vector given by grad - vecjacobian!(res, integrand.p, res, integrand.y, t, fakeSp; - dgrad=nothing, dy=nothing) - integrand = update_p_integrand(integrand,_p) - end - - w = let tprev=tprev, pos_neg=pos_neg - function (du,u,p,t) - _affect! = get_affect!(cb,pos_neg) - fakeinteg = FakeIntegrator([x for x in u],[x for x in p],t,tprev) - _affect!(fakeinteg) - du .= vec(fakeinteg.u) + + w = let tprev = tprev, pos_neg = pos_neg + function (du, u, p, t) + _affect! = get_affect!(cb, pos_neg) + fakeinteg = FakeIntegrator([x for x in u], [x for x in p], t, tprev) + _affect!(fakeinteg) + du .= vec(fakeinteg.u) + end end - end - - # Create a fake sensitivity function to do the vjps needs to be done - # to account for parameter dependence of affect function - fakeS = CallbackSensitivityFunction(w,sensealg,adj_prob.f.f.diffcache,sol.prob) - if dg !== nothing # discrete cost - dg(dλ, integrand.y, integrand.p, t, cur_time) - else - error("Please provide `dg_discrete` to use adjoint_sensitivities with `QuadratureAdjoint()` and callbacks.") - end - - # account for implicit events - - @. dλ = -dλ-integrand.λ - vecjacobian!(dλ, integrand.y, dλ, integrand.p, t, fakeS; dgrad=dgrad) - res .-= dgrad - return integrand + + # Create a fake sensitivity function to do the vjps needs to be done + # to account for parameter dependence of affect function + fakeS = CallbackSensitivityFunction(w, sensealg, adj_prob.f.f.diffcache, sol.prob) + if dg !== nothing # discrete cost + dg(dλ, integrand.y, integrand.p, t, cur_time) + else + error("Please provide `dg_discrete` to use adjoint_sensitivities with `QuadratureAdjoint()` and callbacks.") + end + + # account for implicit events + + @. dλ = -dλ - integrand.λ + vecjacobian!(dλ, integrand.y, dλ, integrand.p, t, fakeS; dgrad = dgrad) + res .-= dgrad + return integrand end diff --git a/src/reversediff.jl b/src/reversediff.jl index e6339fa76..74c502363 100644 --- a/src/reversediff.jl +++ b/src/reversediff.jl @@ -4,9 +4,18 @@ DiffEqBase.value(x::ReverseDiff.TrackedReal) = x.value DiffEqBase.value(x::ReverseDiff.TrackedArray) = x.value DiffEqBase.promote_u0(u0::ReverseDiff.TrackedArray, p::ReverseDiff.TrackedArray, t0) = u0 -DiffEqBase.promote_u0(u0::AbstractArray{<:ReverseDiff.TrackedReal}, p::ReverseDiff.TrackedArray, t0) = u0 -DiffEqBase.promote_u0(u0::ReverseDiff.TrackedArray, p::AbstractArray{<:ReverseDiff.TrackedReal}, t0) = u0 -DiffEqBase.promote_u0(u0::AbstractArray{<:ReverseDiff.TrackedReal}, p::AbstractArray{<:ReverseDiff.TrackedReal}, t0) = u0 +function DiffEqBase.promote_u0(u0::AbstractArray{<:ReverseDiff.TrackedReal}, + p::ReverseDiff.TrackedArray, t0) + u0 +end +function DiffEqBase.promote_u0(u0::ReverseDiff.TrackedArray, + p::AbstractArray{<:ReverseDiff.TrackedReal}, t0) + u0 +end +function DiffEqBase.promote_u0(u0::AbstractArray{<:ReverseDiff.TrackedReal}, + p::AbstractArray{<:ReverseDiff.TrackedReal}, t0) + u0 +end DiffEqBase.promote_u0(u0, p::ReverseDiff.TrackedArray, t0) = ReverseDiff.track(u0) DiffEqBase.promote_u0(u0, p::AbstractArray{<:ReverseDiff.TrackedReal}, t0) = eltype(p).(u0) @@ -14,52 +23,82 @@ DiffEqBase.promote_u0(u0, p::AbstractArray{<:ReverseDiff.TrackedReal}, t0) = elt @inline function DiffEqBase.ODE_DEFAULT_NORM(u::ReverseDiff.TrackedArray, t) where {N} sqrt(sum(abs2, DiffEqBase.value(u)) / length(u)) end -@inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{<:ReverseDiff.TrackedReal,N}, t) where {N} - sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip((DiffEqBase.value(x) for x in u), Iterators.repeated(t))) / length(u)) +@inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{<:ReverseDiff.TrackedReal, N}, + t) where {N} + sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), + zip((DiffEqBase.value(x) for x in u), Iterators.repeated(t))) / length(u)) end -@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Array{<:ReverseDiff.TrackedReal,N}, t) where {N} - sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip((DiffEqBase.value(x) for x in u), Iterators.repeated(t))) / length(u)) +@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Array{<:ReverseDiff.TrackedReal, N}, + t) where {N} + sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), + zip((DiffEqBase.value(x) for x in u), Iterators.repeated(t))) / length(u)) +end +@inline function DiffEqBase.ODE_DEFAULT_NORM(u::ReverseDiff.TrackedReal, t) + abs(DiffEqBase.value(u)) end -@inline DiffEqBase.ODE_DEFAULT_NORM(u::ReverseDiff.TrackedReal, t) = abs(DiffEqBase.value(u)) # Support TrackedReal time, don't drop tracking on the adaptivity there -@inline function DiffEqBase.ODE_DEFAULT_NORM(u::ReverseDiff.TrackedArray, t::ReverseDiff.TrackedReal) where {N} +@inline function DiffEqBase.ODE_DEFAULT_NORM(u::ReverseDiff.TrackedArray, + t::ReverseDiff.TrackedReal) where {N} sqrt(sum(abs2, u) / length(u)) end -@inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{<:ReverseDiff.TrackedReal,N}, t::ReverseDiff.TrackedReal) where {N} - sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip(u, Iterators.repeated(t))) / length(u)) +@inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{<:ReverseDiff.TrackedReal, N}, + t::ReverseDiff.TrackedReal) where {N} + sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip(u, Iterators.repeated(t))) / + length(u)) +end +@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Array{<:ReverseDiff.TrackedReal, N}, + t::ReverseDiff.TrackedReal) where {N} + sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip(u, Iterators.repeated(t))) / + length(u)) end -@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Array{<:ReverseDiff.TrackedReal,N}, t::ReverseDiff.TrackedReal) where {N} - sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip(u, Iterators.repeated(t))) / length(u)) +@inline function DiffEqBase.ODE_DEFAULT_NORM(u::ReverseDiff.TrackedReal, + t::ReverseDiff.TrackedReal) + abs(u) end -@inline DiffEqBase.ODE_DEFAULT_NORM(u::ReverseDiff.TrackedReal, t::ReverseDiff.TrackedReal) = abs(u) -function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem, sensealg::Union{DiffEqBase.AbstractSensitivityAlgorithm,Nothing}, u0::ReverseDiff.TrackedArray, p::ReverseDiff.TrackedArray, args...; kwargs...) +function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem, + sensealg::Union{DiffEqBase.AbstractSensitivityAlgorithm, + Nothing}, u0::ReverseDiff.TrackedArray, + p::ReverseDiff.TrackedArray, args...; kwargs...) ReverseDiff.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) end -function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem, sensealg::Union{DiffEqBase.AbstractSensitivityAlgorithm,Nothing}, u0, p::ReverseDiff.TrackedArray, args...; kwargs...) +function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem, + sensealg::Union{DiffEqBase.AbstractSensitivityAlgorithm, + Nothing}, u0, p::ReverseDiff.TrackedArray, + args...; kwargs...) ReverseDiff.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) end -function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem, sensealg::Union{DiffEqBase.AbstractSensitivityAlgorithm,Nothing}, u0::ReverseDiff.TrackedArray, p, args...; kwargs...) +function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem, + sensealg::Union{DiffEqBase.AbstractSensitivityAlgorithm, + Nothing}, u0::ReverseDiff.TrackedArray, p, + args...; kwargs...) ReverseDiff.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) end -@inline function DiffEqNoiseProcess.wiener_randn(rng::Random.AbstractRNG, proto::ReverseDiff.TrackedArray) +@inline function DiffEqNoiseProcess.wiener_randn(rng::Random.AbstractRNG, + proto::ReverseDiff.TrackedArray) ReverseDiff.track(convert.(eltype(proto.value), randn(rng, size(proto)))) end -@inline function DiffEqNoiseProcess.wiener_randn!(rng::AbstractRNG, rand_vec::Array{<:ReverseDiff.TrackedReal}) +@inline function DiffEqNoiseProcess.wiener_randn!(rng::AbstractRNG, + rand_vec::Array{<:ReverseDiff.TrackedReal + }) rand_vec .= ReverseDiff.track.(randn.((rng,), typeof.(DiffEqBase.value.(rand_vec)))) end -@inline function DiffEqNoiseProcess.wiener_randn!(rng::AbstractRNG, rand_vec::AbstractArray{<:ReverseDiff.TrackedReal}) +@inline function DiffEqNoiseProcess.wiener_randn!(rng::AbstractRNG, + rand_vec::AbstractArray{ + <:ReverseDiff.TrackedReal + }) rand_vec .= ReverseDiff.track.(randn.((rng,), typeof.(DiffEqBase.value.(rand_vec)))) end # Required becase ReverseDiff.@grad function DiffEqBase.solve_up is not supported! import DiffEqBase: solve_up -ReverseDiff.@grad function solve_up(prob,sensealg,u0,p,args...;kwargs...) - out = DiffEqBase._solve_adjoint(prob,sensealg,ReverseDiff.value(u0),ReverseDiff.value(p), - SciMLBase.ReverseDiffOriginator(),args...;kwargs...) - Array(out[1]),out[2] -end \ No newline at end of file +ReverseDiff.@grad function solve_up(prob, sensealg, u0, p, args...; kwargs...) + out = DiffEqBase._solve_adjoint(prob, sensealg, ReverseDiff.value(u0), + ReverseDiff.value(p), + SciMLBase.ReverseDiffOriginator(), args...; kwargs...) + Array(out[1]), out[2] +end diff --git a/src/sde_tools.jl b/src/sde_tools.jl index c2b8f6c52..71593dbb3 100644 --- a/src/sde_tools.jl +++ b/src/sde_tools.jl @@ -1,81 +1,80 @@ # for Ito / Stratonovich conversion -struct StochasticTransformedFunction{pType,fType<:DiffEqBase.AbstractDiffEqFunction,gType,noiseType,cfType} <: TransformedFunction - prob::pType - f::fType - g::gType - gtmp::noiseType - inplace::Bool - corfunc_analytical::cfType +struct StochasticTransformedFunction{pType, fType <: DiffEqBase.AbstractDiffEqFunction, + gType, noiseType, cfType} <: TransformedFunction + prob::pType + f::fType + g::gType + gtmp::noiseType + inplace::Bool + corfunc_analytical::cfType end +function StochasticTransformedFunction(sol, f, g, corfunc_analytical = nothing) + @unpack prob = sol -function StochasticTransformedFunction(sol,f,g, corfunc_analytical=nothing) - @unpack prob = sol - - if StochasticDiffEq.is_diagonal_noise(prob) - gtmp = copy(sol.u[end]) - else - gtmp = similar(prob.p, size(prob.noise_rate_prototype)) - end + if StochasticDiffEq.is_diagonal_noise(prob) + gtmp = copy(sol.u[end]) + else + gtmp = similar(prob.p, size(prob.noise_rate_prototype)) + end - return StochasticTransformedFunction(prob,f,g,gtmp,DiffEqBase.isinplace(prob),corfunc_analytical) + return StochasticTransformedFunction(prob, f, g, gtmp, DiffEqBase.isinplace(prob), + corfunc_analytical) end - -function (Tfunc::StochasticTransformedFunction)(du,u,p,t) - @unpack gtmp, f, g, corfunc_analytical = Tfunc - - ducor = similar(u, size(u)) - - if corfunc_analytical !== nothing - corfunc_analytical(ducor,u,p,t) - else - tape = ReverseDiff.GradientTape((u, p, [t])) do uloc,ploc,tloc - du1 = similar(uloc, size(gtmp)) - g(du1,uloc,ploc,first(tloc)) - return vec(du1) +function (Tfunc::StochasticTransformedFunction)(du, u, p, t) + @unpack gtmp, f, g, corfunc_analytical = Tfunc + + ducor = similar(u, size(u)) + + if corfunc_analytical !== nothing + corfunc_analytical(ducor, u, p, t) + else + tape = ReverseDiff.GradientTape((u, p, [t])) do uloc, ploc, tloc + du1 = similar(uloc, size(gtmp)) + g(du1, uloc, ploc, first(tloc)) + return vec(du1) + end + tu, tp, tt = ReverseDiff.input_hook(tape) + output = ReverseDiff.output_hook(tape) + + ReverseDiff.unseed!(tu) # clear any "leftover" derivatives from previous calls + ReverseDiff.unseed!(tp) + ReverseDiff.unseed!(tt) + + ReverseDiff.value!(tu, u) + ReverseDiff.value!(tp, p) + ReverseDiff.value!(tt, [t]) + + ReverseDiff.forward_pass!(tape) + ReverseDiff.increment_deriv!(output, vec(ReverseDiff.value(output))) + ReverseDiff.reverse_pass!(tape) + + ReverseDiff.deriv(tu) + ReverseDiff.pull_value!(output) + copyto!(vec(ducor), ReverseDiff.deriv(tu)) end - tu, tp, tt = ReverseDiff.input_hook(tape) - output = ReverseDiff.output_hook(tape) - - ReverseDiff.unseed!(tu) # clear any "leftover" derivatives from previous calls - ReverseDiff.unseed!(tp) - ReverseDiff.unseed!(tt) - ReverseDiff.value!(tu, u) - ReverseDiff.value!(tp, p) - ReverseDiff.value!(tt, [t]) + f(du, u, p, t) - ReverseDiff.forward_pass!(tape) - ReverseDiff.increment_deriv!(output, vec(ReverseDiff.value(output))) - ReverseDiff.reverse_pass!(tape) - - ReverseDiff.deriv(tu) - ReverseDiff.pull_value!(output) - copyto!(vec(ducor), ReverseDiff.deriv(tu)) - end - - f(du,u,p,t) - - @. du = du - ducor - return nothing + @. du = du - ducor + return nothing end - -function (Tfunc::StochasticTransformedFunction)(u,p,t) - @unpack f, g, corfunc_analytical = Tfunc - #ducor = vecjacobian(u, p, t, Tfunc) - - if corfunc_analytical !== nothing - ducor = corfunc_analytical(u,p,t) - else - _dy, back = Zygote.pullback(u, p) do uloc, ploc - vec(g(uloc, ploc, t)) +function (Tfunc::StochasticTransformedFunction)(u, p, t) + @unpack f, g, corfunc_analytical = Tfunc + #ducor = vecjacobian(u, p, t, Tfunc) + + if corfunc_analytical !== nothing + ducor = corfunc_analytical(u, p, t) + else + _dy, back = Zygote.pullback(u, p) do uloc, ploc + vec(g(uloc, ploc, t)) + end + ducor, _ = back(_dy) end - ducor, _ = back(_dy) - end - du = f(u,p,t) + du = f(u, p, t) - du = @. du - ducor - return du + du = @. du - ducor + return du end diff --git a/src/second_order.jl b/src/second_order.jl index ba5f1654a..2667dbb1c 100644 --- a/src/second_order.jl +++ b/src/second_order.jl @@ -1,17 +1,18 @@ -function _second_order_sensitivities(loss,prob,alg,sensealg::ForwardDiffOverAdjoint, - args...;kwargs...) - ForwardDiff.jacobian(prob.p) do p - x = Zygote.gradient(p) do _p - loss(solve(prob,alg,args...;p=_p,sensealg=sensealg.adjalg,kwargs...)) - end - first(x) - end -end - -function _second_order_sensitivity_product(loss,v,prob,alg,sensealg::ForwardDiffOverAdjoint, - args...;kwargs...) - - θ = ForwardDiff.Dual.(prob.p,v) - _loss = p -> loss(solve(prob,alg,args...;p=p,sensealg=sensealg.adjalg,kwargs...)) - getindex.(ForwardDiff.partials.(Zygote.gradient(_loss,θ)[1]),1) -end +function _second_order_sensitivities(loss, prob, alg, sensealg::ForwardDiffOverAdjoint, + args...; kwargs...) + ForwardDiff.jacobian(prob.p) do p + x = Zygote.gradient(p) do _p + loss(solve(prob, alg, args...; p = _p, sensealg = sensealg.adjalg, kwargs...)) + end + first(x) + end +end + +function _second_order_sensitivity_product(loss, v, prob, alg, + sensealg::ForwardDiffOverAdjoint, + args...; kwargs...) + θ = ForwardDiff.Dual.(prob.p, v) + _loss = p -> loss(solve(prob, alg, args...; p = p, sensealg = sensealg.adjalg, + kwargs...)) + getindex.(ForwardDiff.partials.(Zygote.gradient(_loss, θ)[1]), 1) +end diff --git a/src/sensitivity_algorithms.jl b/src/sensitivity_algorithms.jl index 4d8502f70..e55cfba03 100644 --- a/src/sensitivity_algorithms.jl +++ b/src/sensitivity_algorithms.jl @@ -1,1075 +1,1126 @@ -SensitivityAlg(args...;kwargs...) = @error("The SensitivtyAlg choice mechanism was completely overhauled. Please consult the local sensitivity documentation for more information") - -abstract type AbstractForwardSensitivityAlgorithm{CS,AD,FDT} <: DiffEqBase.AbstractSensitivityAlgorithm{CS,AD,FDT} end -abstract type AbstractAdjointSensitivityAlgorithm{CS,AD,FDT} <: DiffEqBase.AbstractSensitivityAlgorithm{CS,AD,FDT} end -abstract type AbstractSecondOrderSensitivityAlgorithm{CS,AD,FDT} <: DiffEqBase.AbstractSensitivityAlgorithm{CS,AD,FDT} end -abstract type AbstractShadowingSensitivityAlgorithm{CS,AD,FDT} <: DiffEqBase.AbstractSensitivityAlgorithm{CS,AD,FDT} end - -""" -ForwardSensitivity{CS,AD,FDT} <: AbstractForwardSensitivityAlgorithm{CS,AD,FDT} - -An implementation of continuous forward sensitivity analysis for propagating -derivatives by solving the extended ODE. When used within adjoint differentiation -(i.e. via Zygote), this will cause forward differentiation of the `solve` call -within the reverse-mode automatic differentiation environment. - -## Constructor - -```julia -function ForwardSensitivity(; - chunk_size=0,autodiff=true, - diff_type=Val{:central}, - autojacvec=autodiff, - autojacmat=false) -``` - -## Keyword Arguments - -* `autodiff`: Use automatic differentiation in the internal sensitivity algorithm - computations. Default is `true`. -* `chunk_size`: Chunk size for forward mode differentiation if full Jacobians are - built (`autojacvec=false` and `autodiff=true`). Default is `0` for automatic - choice of chunk size. -* `autojacvec`: Calculate the Jacobian-vector product via automatic - differentiation with special seeding. -* `diff_type`: The method used by FiniteDiff.jl for constructing the Jacobian - if the full Jacobian is required with `autodiff=false`. - -Further details: - -- If `autodiff=true` and `autojacvec=true`, then the one chunk `J*v` forward-mode - directional derivative calculation trick is used to compute the product without - constructing the Jacobian (via ForwardDiff.jl). -- If `autodiff=false` and `autojacvec=true`, then the numerical direction derivative - trick `(f(x+epsilon*v)-f(x))/epsilon` is used to compute `J*v` without constructing - the Jacobian. -- If `autodiff=true` and `autojacvec=false`, then the Jacobian is constructed via - chunked forward-mode automatic differentiation (via ForwardDiff.jl). -- If `autodiff=false` and `autojacvec=false`, then the Jacobian is constructed via - finite differences via FiniteDiff.jl. - -## SciMLProblem Support - -This `sensealg` only supports `ODEProblem`s without callbacks (events). -""" -struct ForwardSensitivity{CS,AD,FDT} <: AbstractForwardSensitivityAlgorithm{CS,AD,FDT} - autojacvec::Bool - autojacmat::Bool -end -Base.@pure function ForwardSensitivity(; - chunk_size=0,autodiff=true, - diff_type=Val{:central}, - autojacvec=autodiff, - autojacmat=false) - autojacvec && autojacmat && error("Choose either Jacobian matrix products or Jacobian vector products, - autojacmat and autojacvec cannot both be true") - ForwardSensitivity{chunk_size,autodiff,diff_type}(autojacvec,autojacmat) -end - -""" -ForwardDiffSensitivity{CS,CTS} <: AbstractForwardSensitivityAlgorithm{CS,Nothing,Nothing} - -An implementation of discrete forward sensitivity analysis through ForwardDiff.jl. -When used within adjoint differentiation (i.e. via Zygote), this will cause forward -differentiation of the `solve` call within the reverse-mode automatic differentiation -environment. - -## Constructor - -```julia -ForwardDiffSensitivity(;chunk_size=0,convert_tspan=nothing) -``` - -## Keyword Arguments - -* `chunk_size`: the chunk size used by ForwardDiff for computing the Jacobian, i.e. the - number of simultaneous columns computed. -* `convert_tspan`: whether to convert time to also be `Dual` valued. By default this is - `nothing` which will only convert if callbacks are found. Conversion is required in order - to accurately differentiate callbacks (hybrid equations). - -## SciMLProblem Support - -This `sensealg` supports any `SciMLProblem`s, provided that the solver algorithms is -`SciMLBase.isautodifferentiable`. Note that `ForwardDiffSensitivity` can -accurately differentiate code with callbacks only when `convert_tspan=true`. -""" -struct ForwardDiffSensitivity{CS,CTS} <: AbstractForwardSensitivityAlgorithm{CS,Nothing,Nothing} -end -Base.@pure function ForwardDiffSensitivity(;chunk_size=0,convert_tspan=nothing) - ForwardDiffSensitivity{chunk_size,convert_tspan}() -end - -""" -BacksolveAdjoint{CS,AD,FDT,VJP} <: AbstractAdjointSensitivityAlgorithm{CS,AD,FDT} - -An implementation of adjoint sensitivity analysis using a backwards solution of the ODE. -By default this algorithm will use the values from the forward pass to perturb the -backwards solution to the correct spot, allowing reduced memory (O(1) memory). Checkpointing -stabilization is included for additional numerical stability over the naive implementation. - -## Constructor - -```julia -BacksolveAdjoint(;chunk_size=0,autodiff=true, - diff_type=Val{:central}, - autojacvec=nothing, - checkpointing=true, noisemixing=false) -``` - -## Keyword Arguments - -* `autodiff`: Use automatic differentiation for constructing the Jacobian - if the Jacobian needs to be constructed. Defaults to `true`. -* `chunk_size`: Chunk size for forward-mode differentiation if full Jacobians are - built (`autojacvec=false` and `autodiff=true`). Default is `0` for automatic - choice of chunk size. -* `diff_type`: The method used by FiniteDiff.jl for constructing the Jacobian - if the full Jacobian is required with `autodiff=false`. -* `autojacvec`: Calculate the vector-Jacobian product (`J'*v`) via automatic - differentiation with special seeding. The default is `true`. The total set - of choices are: - - `false`: the Jacobian is constructed via FiniteDiff.jl - - `true`: the Jacobian is constructed via ForwardDiff.jl - - `TrackerVJP`: Uses Tracker.jl for the vjp. - - `ZygoteVJP`: Uses Zygote.jl for the vjp. - - `EnzymeVJP`: Uses Enzyme.jl for the vjp. - - `ReverseDiffVJP(compile=false)`: Uses ReverseDiff.jl for the vjp. `compile` - is a boolean for whether to precompile the tape, which should only be done - if there are no branches (`if` or `while` statements) in the `f` function. -* `checkpointing`: whether checkpointing is enabled for the reverse pass. Defaults - to `true`. -* `noisemixing`: Handle noise processes that are not of the form `du[i] = f(u[i])`. - For example, to compute the sensitivities of an SDE with diagonal diffusion - ```julia - function g_mixing!(du,u,p,t) - du[1] = p[3]*u[1] + p[4]*u[2] - du[2] = p[3]*u[1] + p[4]*u[2] - nothing - end - ``` - correctly, `noisemixing=true` must be enabled. The default is `false`. - -For more details on the vjp choices, please consult the sensitivity algorithms -documentation page or the docstrings of the vjp types. - -## Applicability of Backsolve and Caution - -When `BacksolveAdjoint` is applicable, it is a fast method and requires the least memory. -However, one must be cautious because not all ODEs are stable under backwards integration -by the majority of ODE solvers. An example of such an equation is the Lorenz equation. -Notice that if one solves the Lorenz equation forward and then in reverse with any -adaptive time step and non-reversible integrator, then the backwards solution diverges -from the forward solution. As a quick demonstration: - -```julia -using Sundials -function lorenz(du,u,p,t) - du[1] = 10.0*(u[2]-u[1]) - du[2] = u[1]*(28.0-u[3]) - u[2] - du[3] = u[1]*u[2] - (8/3)*u[3] -end -u0 = [1.0;0.0;0.0] -tspan = (0.0,100.0) -prob = ODEProblem(lorenz,u0,tspan) -sol = solve(prob,Tsit5(),reltol=1e-12,abstol=1e-12) -prob2 = ODEProblem(lorenz,sol[end],(100.0,0.0)) -sol = solve(prob,Tsit5(),reltol=1e-12,abstol=1e-12) -@show sol[end]-u0 #[-3.22091, -1.49394, 21.3435] -``` - -Thus one should check the stability of the backsolve on their type of problem before -enabling this method. Additionally, using checkpointing with backsolve can be a -low memory way to stabilize it. - -For more details on this topic, see -[Stiff Neural Ordinary Differential Equations](https://aip.scitation.org/doi/10.1063/5.0060697). - -## Checkpointing - -To improve the numerical stability of the reverse pass, `BacksolveAdjoint` includes a checkpointing -feature. If `sol.u` is a time series, then whenever a time `sol.t` is hit while reversing, a callback -will replace the reversing ODE portion with `sol.u[i]`. This nudges the solution back onto the appropriate -trajectory and reduces the numerical caused by drift. - -## SciMLProblem Support - -This `sensealg` only supports `ODEProblem`s, `SDEProblem`s, and `RODEProblem`s. This `sensealg` supports -callback functions (events). - -## References - -ODE: - Rackauckas, C. and Ma, Y. and Martensen, J. and Warner, C. and Zubov, K. and Supekar, - R. and Skinner, D. and Ramadhana, A. and Edelman, A., Universal Differential Equations - for Scientific Machine Learning, arXiv:2001.04385 - - Hindmarsh, A. C. and Brown, P. N. and Grant, K. E. and Lee, S. L. and Serban, R. - and Shumaker, D. E. and Woodward, C. S., SUNDIALS: Suite of nonlinear and - differential/algebraic equation solvers, ACM Transactions on Mathematical - Software (TOMS), 31, pp:363–396 (2005) - - Chen, R.T.Q. and Rubanova, Y. and Bettencourt, J. and Duvenaud, D. K., - Neural ordinary differential equations. In Advances in neural information processing - systems, pp. 6571–6583 (2018) - - Pontryagin, L. S. and Mishchenko, E.F. and Boltyanskii, V.G. and Gamkrelidze, R.V. - The mathematical theory of optimal processes. Routledge, (1962) - - Rackauckas, C. and Ma, Y. and Dixit, V. and Guo, X. and Innes, M. and Revels, J. - and Nyberg, J. and Ivaturi, V., A comparison of automatic differentiation and - continuous sensitivity analysis for derivatives of differential equation solutions, - arXiv:1812.01892 - -DAE: - Cao, Y. and Li, S. and Petzold, L. and Serban, R., Adjoint sensitivity analysis - for differential-algebraic equations: The adjoint DAE system and its numerical - solution, SIAM journal on scientific computing 24 pp: 1076-1089 (2003) - -SDE: - Gobet, E. and Munos, R., Sensitivity Analysis Using Ito-Malliavin Calculus and - Martingales, and Application to Stochastic Optimal Control, - SIAM Journal on control and optimization, 43, pp. 1676-1713 (2005) - - Li, X. and Wong, T.-K. L.and Chen, R. T. Q. and Duvenaud, D., - Scalable Gradients for Stochastic Differential Equations, - PMLR 108, pp. 3870-3882 (2020), http://proceedings.mlr.press/v108/li20i.html -""" -struct BacksolveAdjoint{CS,AD,FDT,VJP} <: AbstractAdjointSensitivityAlgorithm{CS,AD,FDT} - autojacvec::VJP - checkpointing::Bool - noisemixing::Bool -end -Base.@pure function BacksolveAdjoint(;chunk_size=0,autodiff=true, - diff_type=Val{:central}, - autojacvec=nothing, - checkpointing=true, noisemixing=false) - BacksolveAdjoint{chunk_size,autodiff,diff_type,typeof(autojacvec)}(autojacvec,checkpointing,noisemixing) -end -setvjp(sensealg::BacksolveAdjoint{CS,AD,FDT,Nothing}, vjp) where {CS,AD,FDT} = - BacksolveAdjoint{CS,AD,FDT,typeof(vjp)}(vjp,sensealg.checkpointing, - sensealg.noisemixing) - -""" -InterpolatingAdjoint{CS,AD,FDT,VJP} <: AbstractAdjointSensitivityAlgorithm{CS,AD,FDT} - -An implementation of adjoint sensitivity analysis which uses the interpolation of -the forward solution for the reverse solve vector-Jacobian products. By -default it requires a dense solution of the forward pass and will internally -ignore saving arguments during the gradient calculation. When checkpointing is -enabled it will only require the memory to interpolate between checkpoints. - -## Constructor - -```julia -function InterpolatingAdjoint(;chunk_size=0,autodiff=true, - diff_type=Val{:central}, - autojacvec=nothing, - checkpointing=false, noisemixing=false) -``` - -## Keyword Arguments - -* `autodiff`: Use automatic differentiation for constructing the Jacobian - if the Jacobian needs to be constructed. Defaults to `true`. -* `chunk_size`: Chunk size for forward-mode differentiation if full Jacobians are - built (`autojacvec=false` and `autodiff=true`). Default is `0` for automatic - choice of chunk size. -* `diff_type`: The method used by FiniteDiff.jl for constructing the Jacobian - if the full Jacobian is required with `autodiff=false`. -* `autojacvec`: Calculate the vector-Jacobian product (`J'*v`) via automatic - differentiation with special seeding. The default is `true`. The total set - of choices are: - - `false`: the Jacobian is constructed via FiniteDiff.jl - - `true`: the Jacobian is constructed via ForwardDiff.jl - - `TrackerVJP`: Uses Tracker.jl for the vjp. - - `ZygoteVJP`: Uses Zygote.jl for the vjp. - - `EnzymeVJP`: Uses Enzyme.jl for the vjp. - - `ReverseDiffVJP(compile=false)`: Uses ReverseDiff.jl for the vjp. `compile` - is a boolean for whether to precompile the tape, which should only be done - if there are no branches (`if` or `while` statements) in the `f` function. -* `checkpointing`: whether checkpointing is enabled for the reverse pass. Defaults - to `true`. -* `noisemixing`: Handle noise processes that are not of the form `du[i] = f(u[i])`. - For example, to compute the sensitivities of an SDE with diagonal diffusion - ```julia - function g_mixing!(du,u,p,t) - du[1] = p[3]*u[1] + p[4]*u[2] - du[2] = p[3]*u[1] + p[4]*u[2] - nothing - end - ``` - correctly, `noisemixing=true` must be enabled. The default is `false`. - -For more details on the vjp choices, please consult the sensitivity algorithms -documentation page or the docstrings of the vjp types. - -## Checkpointing - -To reduce the memory usage of the reverse pass, `InterpolatingAdjoint` includes a checkpointing -feature. If `sol` is `dense`, checkpointing is ignored and the continuous solution is used for -calculating `u(t)` at arbitrary time points. If `checkpointing=true` and `sol` is not `dense`, -then dense intervals between `sol.t[i]` and `sol.t[i+1]` are reconstructed on-demand for calculating -`u(t)` at arbitrary time points. This reduces the total memory requirement to only the cost of -holding the dense solution over the largest time interval (in terms of number of required steps). -The total compute cost is no more than double the original forward compute cost. - -## SciMLProblem Support - -This `sensealg` only supports `ODEProblem`s, `SDEProblem`s, and `RODEProblem`s. This `sensealg` -supports callbacks (events). - -## References - - Rackauckas, C. and Ma, Y. and Martensen, J. and Warner, C. and Zubov, K. and Supekar, - R. and Skinner, D. and Ramadhana, A. and Edelman, A., Universal Differential Equations - for Scientific Machine Learning, arXiv:2001.04385 - - Hindmarsh, A. C. and Brown, P. N. and Grant, K. E. and Lee, S. L. and Serban, R. - and Shumaker, D. E. and Woodward, C. S., SUNDIALS: Suite of nonlinear and - differential/algebraic equation solvers, ACM Transactions on Mathematical - Software (TOMS), 31, pp:363–396 (2005) - - Rackauckas, C. and Ma, Y. and Dixit, V. and Guo, X. and Innes, M. and Revels, J. - and Nyberg, J. and Ivaturi, V., A comparison of automatic differentiation and - continuous sensitivity analysis for derivatives of differential equation solutions, - arXiv:1812.01892 -""" -struct InterpolatingAdjoint{CS,AD,FDT,VJP} <: AbstractAdjointSensitivityAlgorithm{CS,AD,FDT} - autojacvec::VJP - checkpointing::Bool - noisemixing::Bool -end -Base.@pure function InterpolatingAdjoint(;chunk_size=0,autodiff=true, - diff_type=Val{:central}, - autojacvec=nothing, - checkpointing=false,noisemixing=false) - InterpolatingAdjoint{chunk_size,autodiff,diff_type,typeof(autojacvec)}(autojacvec,checkpointing,noisemixing) -end -setvjp(sensealg::InterpolatingAdjoint{CS,AD,FDT,Nothing},vjp) where {CS,AD,FDT} = - InterpolatingAdjoint{CS,AD,FDT,typeof(vjp)}(vjp,sensealg.checkpointing, - sensealg.noisemixing) - -""" -QuadratureAdjoint{CS,AD,FDT,VJP} <: AbstractAdjointSensitivityAlgorithm{CS,AD,FDT} - -An implementation of adjoint sensitivity analysis which develops a full -continuous solution of the reverse solve in order to perform a post-ODE -quadrature. This method requires the the dense solution and will ignore -saving arguments during the gradient calculation. The tolerances in the -constructor control the inner quadrature. The inner quadrature uses a -ReverseDiff vjp if autojacvec, and `compile=false` by default but can -compile the tape under the same circumstances as `ReverseDiffVJP`. - -This method is O(n^3 + p) for stiff / implicit equations (as opposed to the -O((n+p)^3) scaling of BacksolveAdjoint and InterpolatingAdjoint), and thus -is much more compute efficient. However, it requires holding a dense reverse -pass and is thus memory intensive. - -## Constructor - -```julia -function QuadratureAdjoint(;chunk_size=0,autodiff=true, - diff_type=Val{:central}, - autojacvec=nothing,abstol=1e-6, - reltol=1e-3,compile=false) -``` - -## Keyword Arguments - -* `autodiff`: Use automatic differentiation for constructing the Jacobian - if the Jacobian needs to be constructed. Defaults to `true`. -* `chunk_size`: Chunk size for forward-mode differentiation if full Jacobians are - built (`autojacvec=false` and `autodiff=true`). Default is `0` for automatic - choice of chunk size. -* `diff_type`: The method used by FiniteDiff.jl for constructing the Jacobian - if the full Jacobian is required with `autodiff=false`. -* `autojacvec`: Calculate the vector-Jacobian product (`J'*v`) via automatic - differentiation with special seeding. The default is `true`. The total set - of choices are: - - `false`: the Jacobian is constructed via FiniteDiff.jl - - `true`: the Jacobian is constructed via ForwardDiff.jl - - `TrackerVJP`: Uses Tracker.jl for the vjp. - - `ZygoteVJP`: Uses Zygote.jl for the vjp. - - `EnzymeVJP`: Uses Enzyme.jl for the vjp. - - `ReverseDiffVJP(compile=false)`: Uses ReverseDiff.jl for the vjp. `compile` - is a boolean for whether to precompile the tape, which should only be done - if there are no branches (`if` or `while` statements) in the `f` function. -* `abstol`: absolute tolerance for the quadrature calculation -* `reltol`: relative tolerance for the quadrature calculation -* `compile`: whether to compile the vjp calculation for the integrand calculation. - See `ReverseDiffVJP` for more details. - -For more details on the vjp choices, please consult the sensitivity algorithms -documentation page or the docstrings of the vjp types. - -## SciMLProblem Support - -This `sensealg` only supports `ODEProblem`s. This `sensealg` supports events (callbacks). - -## References - - Rackauckas, C. and Ma, Y. and Martensen, J. and Warner, C. and Zubov, K. and Supekar, - R. and Skinner, D. and Ramadhana, A. and Edelman, A., Universal Differential Equations - for Scientific Machine Learning, arXiv:2001.04385 - - Hindmarsh, A. C. and Brown, P. N. and Grant, K. E. and Lee, S. L. and Serban, R. - and Shumaker, D. E. and Woodward, C. S., SUNDIALS: Suite of nonlinear and - differential/algebraic equation solvers, ACM Transactions on Mathematical - Software (TOMS), 31, pp:363–396 (2005) - - Rackauckas, C. and Ma, Y. and Dixit, V. and Guo, X. and Innes, M. and Revels, J. - and Nyberg, J. and Ivaturi, V., A comparison of automatic differentiation and - continuous sensitivity analysis for derivatives of differential equation solutions, - arXiv:1812.01892 - - Kim, S., Ji, W., Deng, S., Ma, Y., & Rackauckas, C. (2021). Stiff neural ordinary - differential equations. Chaos: An Interdisciplinary Journal of Nonlinear Science, 31(9), 093122. -""" -struct QuadratureAdjoint{CS,AD,FDT,VJP} <: AbstractAdjointSensitivityAlgorithm{CS,AD,FDT} - autojacvec::VJP - abstol::Float64 - reltol::Float64 -end -Base.@pure function QuadratureAdjoint(;chunk_size=0,autodiff=true, - diff_type=Val{:central}, - autojacvec=nothing,abstol=1e-6, - reltol=1e-3) - QuadratureAdjoint{chunk_size,autodiff,diff_type,typeof(autojacvec)}(autojacvec,abstol,reltol) -end -setvjp(sensealg::QuadratureAdjoint{CS,AD,FDT,Nothing},vjp) where {CS,AD,FDT} = - QuadratureAdjoint{CS,AD,FDT,typeof(vjp)}(vjp,sensealg.abstol, - sensealg.reltol) - -""" -TrackerAdjoint <: AbstractAdjointSensitivityAlgorithm{nothing,true,nothing} - -An implementation of discrete adjoint sensitivity analysis -using the Tracker.jl tracing-based AD. Supports in-place functions through -an Array of Structs formulation, and supports out of place through struct of -arrays. - -## Constructor - -```julia -TrackerAdjoint() -``` - -## SciMLProblem Support - -This `sensealg` supports any `DEProblem` if the algorithm is `SciMLBase.isautodifferentiable` -Compatible with a limited subset of `AbstractArray` types for `u0`, including `CuArrays`. - -!!! warn - - TrackerAdjoint is incompatible with Stiff ODE solvers using forward-mode automatic - differentiation for the Jacobians. Thus for example, `TRBDF2()` will error. Instead, - use `autodiff=false`, i.e. `TRBDF2(autodiff=false)`. This will only remove the - forward-mode automatic differentiation of the Jacobian construction, not the reverse-mode - AD usage, and thus performance will still be nearly the same, though Jacobian accuracy - may suffer which could cause more steps to be required. -""" -struct TrackerAdjoint <: AbstractAdjointSensitivityAlgorithm{nothing,true,nothing} end - -""" -ReverseDiffAdjoint <: AbstractAdjointSensitivityAlgorithm{nothing,true,nothing} - -An implementation of discrete adjoint sensitivity analysis using the ReverseDiff.jl -tracing-based AD. Supports in-place functions through an Array of Structs formulation, -and supports out of place through struct of arrays. - -## Constructor - -```julia -ReverseDiffAdjoint() -``` - -## SciMLProblem Support - -This `sensealg` supports any `DEProblem` if the algorithm is `SciMLBase.isautodifferentiable`. -Requires that the state variables are CPU-based `Array` types. -""" -struct ReverseDiffAdjoint <: AbstractAdjointSensitivityAlgorithm{nothing,true,nothing} end - -""" -ZygoteAdjoint <: AbstractAdjointSensitivityAlgorithm{nothing,true,nothing} - -An implementation of discrete adjoint sensitivity analysis -using the Zygote.jl source-to-source AD directly on the differential equation -solver. - -## Constructor - -```julia -ZygoteAdjoint() -``` - -## SciMLProblem Support - -Currently fails on almost every solver. -""" -struct ZygoteAdjoint <: AbstractAdjointSensitivityAlgorithm{nothing,true,nothing} end - -""" -ForwardLSS{CS,AD,FDT,RType,gType} <: AbstractShadowingSensitivityAlgorithm{CS,AD,FDT} - -An implementation of the discrete, forward-mode -[least squares shadowing](https://arxiv.org/abs/1204.0159) (LSS) method. LSS replaces -the ill-conditioned initial value probem (`ODEProblem`) for chaotic systems by a -well-conditioned least-squares problem. This allows for computing sensitivities of -long-time averaged quantities with respect to the parameters of the `ODEProblem`. The -computational cost of LSS scales as (number of states x number of time steps). Converges -to the correct sensitivity at a rate of `T^(-1/2)`, where `T` is the time of the trajectory. -See `NILSS()` and `NILSAS()` for a more efficient non-intrusive formulation. - -## Constructor - -```julia -ForwardLSS(; - chunk_size=0,autodiff=true, - diff_type=Val{:central}, - LSSregularizer=TimeDilation(10.0,0.0,0.0), - g=nothing) -``` - -## Keyword Arguments - -* `autodiff`: Use automatic differentiation for constructing the Jacobian - if the Jacobian needs to be constructed. Defaults to `true`. -* `chunk_size`: Chunk size for forward-mode differentiation if full Jacobians are - built (`autojacvec=false` and `autodiff=true`). Default is `0` for automatic - choice of chunk size. -* `diff_type`: The method used by FiniteDiff.jl for constructing the Jacobian - if the full Jacobian is required with `autodiff=false`. -* `LSSregularizer`: Using `LSSregularizer`, one can choose between three different - regularization routines. The default choice is `TimeDilation(10.0,0.0,0.0)`. - - `CosWindowing()`: cos windowing of the time grid, i.e. the time grid (saved - time steps) is transformed using a cosine. - - `Cos2Windowing()`: cos^2 windowing of the time grid. - - `TimeDilation(alpha::Number,t0skip::Number,t1skip::Number)`: Corresponds to - a time dilation. `alpha` controls the weight. `t0skip` and `t1skip` indicate - the times truncated at the beginnning and end of the trajectory, respectively. -* `g`: instantaneous objective function of the long-time averaged objective. - -## SciMLProblem Support - -This `sensealg` only supports `ODEProblem`s. This `sensealg` does not support -events (callbacks). This `sensealg` assumes that the objective is a long-time averaged -quantity and ergodic, i.e. the time evolution of the system behaves qualitatively the -same over infinite time independent of the specified initial conditions, such that only -the sensitivity with respect to the parameters is of interest. - -## References - -Wang, Q., Hu, R., and Blonigan, P. Least squares shadowing sensitivity analysis of -chaotic limit cycle oscillations. Journal of Computational Physics, 267, 210-224 (2014). - -Wang, Q., Convergence of the Least Squares Shadowing Method for Computing Derivative of Ergodic -Averages, SIAM Journal on Numerical Analysis, 52, 156–170 (2014). - -Blonigan, P., Gomez, S., Wang, Q., Least Squares Shadowing for sensitivity analysis of turbulent -fluid flows, in: 52nd Aerospace Sciences Meeting, 1–24 (2014). -""" -struct ForwardLSS{CS,AD,FDT,RType,gType} <: AbstractShadowingSensitivityAlgorithm{CS,AD,FDT} - LSSregularizer::RType - g::gType -end -Base.@pure function ForwardLSS(; - chunk_size=0, autodiff=true, - diff_type=Val{:central}, - LSSregularizer=TimeDilation(10.0,0.0,0.0), - g=nothing) - - ForwardLSS{chunk_size,autodiff,diff_type,typeof(LSSregularizer),typeof(g)}(LSSregularizer, g) -end - -""" -AdjointLSS{CS,AD,FDT,RType,gType} <: AbstractShadowingSensitivityAlgorithm{CS,AD,FDT} - -An implementation of the discrete, adjoint-mode -[least square shadowing](https://arxiv.org/abs/1204.0159) method. LSS replaces -the ill-conditioned initial value probem (`ODEProblem`) for chaotic systems by a -well-conditioned least-squares problem. This allows for computing sensitivities of -long-time averaged quantities with respect to the parameters of the `ODEProblem`. The -computational cost of LSS scales as (number of states x number of time steps). Converges -to the correct sensitivity at a rate of `T^(-1/2)`, where `T` is the time of the trajectory. -See `NILSS()` and `NILSAS()` for a more efficient non-intrusive formulation. - -## Constructor - -```julia -AdjointLSS(; - chunk_size=0,autodiff=true, - diff_type=Val{:central}, - LSSRegularizer=TimeDilation(10.0,0.0,0.0), - g=nothing) -``` - -## Keyword Arguments - -* `autodiff`: Use automatic differentiation for constructing the Jacobian - if the Jacobian needs to be constructed. Defaults to `true`. -* `chunk_size`: Chunk size for forward-mode differentiation if full Jacobians are - built (`autojacvec=false` and `autodiff=true`). Default is `0` for automatic - choice of chunk size. -* `diff_type`: The method used by FiniteDiff.jl for constructing the Jacobian - if the full Jacobian is required with `autodiff=false`. -* `LSSregularizer`: Using `LSSregularizer`, one can choose between different - regularization routines. The default choice is `TimeDilation(10.0,0.0,0.0)`. - - `TimeDilation(alpha::Number,t0skip::Number,t1skip::Number)`: Corresponds to - a time dilation. `alpha` controls the weight. `t0skip` and `t1skip` indicate - the times truncated at the beginnning and end of the trajectory, respectively. - The default value for `t0skip` and `t1skip` is `zero(alpha)`. -* `g`: instantaneous objective function of the long-time averaged objective. - -## SciMLProblem Support - -This `sensealg` only supports `ODEProblem`s. This `sensealg` does not support -events (callbacks). This `sensealg` assumes that the objective is a long-time averaged -quantity and ergodic, i.e. the time evolution of the system behaves qualitatively the -same over infinite time independent of the specified initial conditions, such that only -the sensitivity with respect to the parameters is of interest. - -## References - -Wang, Q., Hu, R., and Blonigan, P. Least squares shadowing sensitivity analysis of -chaotic limit cycle oscillations. Journal of Computational Physics, 267, 210-224 (2014). -""" -struct AdjointLSS{CS,AD,FDT,RType,gType} <: AbstractShadowingSensitivityAlgorithm{CS,AD,FDT} - LSSregularizer::RType - g::gType -end -Base.@pure function AdjointLSS(; - chunk_size=0, autodiff=true, - diff_type=Val{:central}, - LSSregularizer=TimeDilation(10.0, 0.0, 0.0), - g=nothing) - AdjointLSS{chunk_size,autodiff,diff_type,typeof(LSSregularizer),typeof(g)}(LSSregularizer, g) -end - -abstract type AbstractLSSregularizer end -abstract type AbstractCosWindowing <: AbstractLSSregularizer end -struct CosWindowing <: AbstractCosWindowing end -struct Cos2Windowing <: AbstractCosWindowing end - -""" -TimeDilation{T1<:Number} <: AbstractLSSregularizer - -A regularization method for `LSS`. See `?LSS` for -additional information and other methods. - -## Constructor - -```julia -TimeDilation(alpha; - t0skip=zero(alpha), - t1skip=zero(alpha)) -``` -""" -struct TimeDilation{T1<:Number} <: AbstractLSSregularizer - alpha::T1 # alpha: weight of the time dilation term in LSS. - t0skip::T1 - t1skip::T1 -end -function TimeDilation(alpha,t0skip=zero(alpha),t1skip=zero(alpha)) - TimeDilation{typeof(alpha)}(alpha,t0skip,t1skip) -end -""" -struct NILSS{CS,AD,FDT,RNG,nType,gType} <: AbstractShadowingSensitivityAlgorithm{CS,AD,FDT} - -An implementation of the forward-mode, continuous -[non-intrusive least squares shadowing](https://arxiv.org/abs/1611.00880) method. `NILSS` -allows for computing sensitivities of long-time averaged quantities with respect to the -parameters of an `ODEProblem` by constraining the computation to the unstable subspace. -`NILSS` employs the continuous-time `ForwardSensitivity` method as tangent solver. To -avoid an exponential blow-up of the (homogenous and inhomogenous) tangent solutions, -the trajectory should be divided into sufficiently small segments, where the tangent solutions -are rescaled on the interfaces. The computational and memory cost of NILSS scale with -the number of unstable (positive) Lyapunov exponents (instead of the number of states as -in the LSS method). `NILSS` avoids the explicit construction of the Jacobian at each time -step and thus should generally be preferred (for large system sizes) over `ForwardLSS`. - -## Constructor - -```julia -NILSS(nseg, nstep; nus = nothing, - rng = Xorshifts.Xoroshiro128Plus(rand(UInt64)), - chunk_size=0,autodiff=true, - diff_type=Val{:central}, - autojacvec=autodiff, - g=nothing) -``` - -## Arguments - -* `nseg`: Number of segments on full time interval on the attractor. -* `nstep`: number of steps on each segment. - -## Keyword Arguments - -* `nus`: Dimension of the unstable subspace. Default is `nothing`. `nus` must be - smaller or equal to the state dimension (`length(u0)`). With the default choice, - `nus = length(u0) - 1` will be set at compile time. -* `rng`: (Pseudo) random number generator. Used for initializing the homogenous - tangent states (`w`). Default is `Xorshifts.Xoroshiro128Plus(rand(UInt64))`. -* `autodiff`: Use automatic differentiation in the internal sensitivity algorithm - computations. Default is `true`. -* `chunk_size`: Chunk size for forward mode differentiation if full Jacobians are - built (`autojacvec=false` and `autodiff=true`). Default is `0` for automatic - choice of chunk size. -* `autojacvec`: Calculate the Jacobian-vector product via automatic - differentiation with special seeding. -* `diff_type`: The method used by FiniteDiff.jl for constructing the Jacobian - if the full Jacobian is required with `autodiff=false`. -* `g`: instantaneous objective function of the long-time averaged objective. - -## SciMLProblem Support - -This `sensealg` only supports `ODEProblem`s. This `sensealg` does not support -events (callbacks). This `sensealg` assumes that the objective is a long-time averaged -quantity and ergodic, i.e. the time evolution of the system behaves qualitatively the -same over infinite time independent of the specified initial conditions, such that only -the sensitivity with respect to the parameters is of interest. - -## References -Ni, A., Blonigan, P. J., Chater, M., Wang, Q., Zhang, Z., Sensitivity analy- -sis on chaotic dynamical system by Non-Intrusive Least Square Shadowing -(NI-LSS), in: 46th AIAA Fluid Dynamics Conference, AIAA AVIATION Forum (AIAA 2016-4399), -American Institute of Aeronautics and Astronautics, 1–16 (2016). - -Ni, A., and Wang, Q. Sensitivity analysis on chaotic dynamical systems by Non-Intrusive -Least Squares Shadowing (NILSS). Journal of Computational Physics 347, 56-77 (2017). -""" -struct NILSS{CS,AD,FDT,RNG,nType,gType} <: AbstractShadowingSensitivityAlgorithm{CS,AD,FDT} - rng::RNG - nseg::Int - nstep::Int - nus::nType - autojacvec::Bool - g::gType -end -Base.@pure function NILSS(nseg, nstep; nus=nothing, rng=Xorshifts.Xoroshiro128Plus(rand(UInt64)), - chunk_size=0, autodiff=true, - diff_type=Val{:central}, - autojacvec=autodiff, - g=nothing -) - NILSS{chunk_size,autodiff,diff_type,typeof(rng),typeof(nus),typeof(g)}(rng,nseg,nstep,nus,autojacvec,g) -end - -""" -NILSAS{CS,AD,FDT,RNG,SENSE,gType} <: AbstractShadowingSensitivityAlgorithm{CS,AD,FDT} - -An implementation of the adjoint-mode, continuous -[non-intrusive adjoint least squares shadowing](https://arxiv.org/abs/1801.08674) method. -`NILSAS` allows for computing sensitivities of long-time averaged quantities with respect -to the parameters of an `ODEProblem` by constraining the computation to the unstable subspace. -`NILSAS` employs SciMLSensitivity.jl's continuous adjoint sensitivity methods on each segment -to compute (homogenous and inhomogenous) adjoint solutions. To avoid an exponential blow-up -of the adjoint solutions, the trajectory should be divided into sufficiently small segments, -where the adjoint solutions are rescaled on the interfaces. The computational and memory cost -of NILSAS scale with the number of unstable, adjoint Lyapunov exponents (instead of the number -of states as in the LSS method). `NILSAS` avoids the explicit construction of the Jacobian at -each time step and thus should generally be preferred (for large system sizes) over `AdjointLSS`. -`NILSAS` is favourable over `NILSS` for many parameters because NILSAS computes the gradient -with respect to multiple parameters with negligible additional cost. - -## Constructor - -```julia -NILSAS(nseg, nstep, M=nothing; rng = Xorshifts.Xoroshiro128Plus(rand(UInt64)), - adjoint_sensealg = BacksolveAdjoint(autojacvec=ReverseDiffVJP()), - chunk_size=0,autodiff=true, - diff_type=Val{:central}, - g=nothing - ) -``` - -## Arguments - -* `nseg`: Number of segments on full time interval on the attractor. -* `nstep`: number of steps on each segment. -* `M`: number of homogenous adjoint solutions. This number must be bigger or equal - than the number of (positive, adjoint) Lyapunov exponents. Default is `nothing`. - -## Keyword Arguments - -* `rng`: (Pseudo) random number generator. Used for initializing the terminate - conditions of the homogenous adjoint states (`w`). Default is `Xorshifts.Xoroshiro128Plus(rand(UInt64))`. -* `adjoint_sensealg`: Continuous adjoint sensitivity method to compute homogenous - and inhomogenous adjoint solutions on each segment. Default is `BacksolveAdjoint(autojacvec=ReverseDiffVJP())`. - * `autojacvec`: Calculate the vector-Jacobian product (`J'*v`) via automatic - differentiation with special seeding. The default is `true`. The total set - of choices are: - - `false`: the Jacobian is constructed via FiniteDiff.jl - - `true`: the Jacobian is constructed via ForwardDiff.jl - - `TrackerVJP`: Uses Tracker.jl for the vjp. - - `ZygoteVJP`: Uses Zygote.jl for the vjp. - - `EnzymeVJP`: Uses Enzyme.jl for the vjp. - - `ReverseDiffVJP(compile=false)`: Uses ReverseDiff.jl for the vjp. `compile` - is a boolean for whether to precompile the tape, which should only be done - if there are no branches (`if` or `while` statements) in the `f` function. -* `autodiff`: Use automatic differentiation for constructing the Jacobian - if the Jacobian needs to be constructed. Defaults to `true`. -* `chunk_size`: Chunk size for forward-mode differentiation if full Jacobians are - built (`autojacvec=false` and `autodiff=true`). Default is `0` for automatic - choice of chunk size. -* `diff_type`: The method used by FiniteDiff.jl for constructing the Jacobian - if the full Jacobian is required with `autodiff=false`. -* `g`: instantaneous objective function of the long-time averaged objective. - -## SciMLProblem Support - -This `sensealg` only supports `ODEProblem`s. This `sensealg` does not support -events (callbacks). This `sensealg` assumes that the objective is a long-time averaged -quantity and ergodic, i.e. the time evolution of the system behaves qualitatively the -same over infinite time independent of the specified initial conditions, such that only -the sensitivity with respect to the parameters is of interest. - -## References - -Ni, A., and Talnikar, C., Adjoint sensitivity analysis on chaotic dynamical systems -by Non-Intrusive Least Squares Adjoint Shadowing (NILSAS). Journal of Computational -Physics 395, 690-709 (2019). -""" -struct NILSAS{CS,AD,FDT,RNG,SENSE,gType} <: AbstractShadowingSensitivityAlgorithm{CS,AD,FDT} - rng::RNG - adjoint_sensealg::SENSE - M::Int - nseg::Int - nstep::Int - g::gType -end -Base.@pure function NILSAS(nseg, nstep, M=nothing; rng=Xorshifts.Xoroshiro128Plus(rand(UInt64)), - adjoint_sensealg=BacksolveAdjoint(autojacvec=ReverseDiffVJP()), - chunk_size=0, autodiff=true, - diff_type=Val{:central}, - g=nothing -) - - # integer dimension of the unstable subspace - M === nothing && error("Please provide an `M` with `M >= nus + 1`, where nus is the number of unstable covariant Lyapunov vectors.") - - NILSAS{chunk_size,autodiff,diff_type,typeof(rng),typeof(adjoint_sensealg),typeof(g)}(rng, adjoint_sensealg, M, - nseg, nstep, g) -end - -""" -SteadyStateAdjoint{CS,AD,FDT,VJP,LS} <: AbstractAdjointSensitivityAlgorithm{CS,AD,FDT} - -An implementation of the adjoint differentiation of a nonlinear solve. Uses the -implicit function theorem to directly compute the derivative of the solution to -``f(u,p) = 0`` with respect to `p`. - -## Constructor - -```julia -SteadyStateAdjoint(;chunk_size = 0, autodiff = true, - diff_type = Val{:central}, - autojacvec = autodiff, linsolve = nothing) -``` - -## Keyword Arguments - -* `autodiff`: Use automatic differentiation for constructing the Jacobian - if the Jacobian needs to be constructed. Defaults to `true`. -* `chunk_size`: Chunk size for forward-mode differentiation if full Jacobians are - built (`autojacvec=false` and `autodiff=true`). Default is `0` for automatic - choice of chunk size. -* `diff_type`: The method used by FiniteDiff.jl for constructing the Jacobian - if the full Jacobian is required with `autodiff=false`. -* `autojacvec`: Calculate the vector-Jacobian product (`J'*v`) via automatic - differentiation with special seeding. The default is `nothing`. The total set - of choices are: - - `false`: the Jacobian is constructed via FiniteDiff.jl - - `true`: the Jacobian is constructed via ForwardDiff.jl - - `TrackerVJP`: Uses Tracker.jl for the vjp. - - `ZygoteVJP`: Uses Zygote.jl for the vjp. - - `EnzymeVJP`: Uses Enzyme.jl for the vjp. - - `ReverseDiffVJP(compile=false)`: Uses ReverseDiff.jl for the vjp. `compile` - is a boolean for whether to precompile the tape, which should only be done - if there are no branches (`if` or `while` statements) in the `f` function. -* `linsolve`: the linear solver used in the adjoint solve. Defaults to `nothing`, - which uses a polyalgorithm to attempt to automatically choose an efficient - algorithm. - -For more details on the vjp choices, please consult the sensitivity algorithms -documentation page or the docstrings of the vjp types. - -## References - -Johnson, S. G., Notes on Adjoint Methods for 18.336, Online at -http://math.mit.edu/stevenj/18.336/adjoint.pdf (2007) -""" -struct SteadyStateAdjoint{CS,AD,FDT,VJP,LS} <: AbstractAdjointSensitivityAlgorithm{CS,AD,FDT} - autojacvec::VJP - linsolve::LS -end - -Base.@pure function SteadyStateAdjoint(;chunk_size = 0, autodiff = true, diff_type = Val{:central}, - autojacvec = nothing, linsolve = nothing) - SteadyStateAdjoint{chunk_size,autodiff,diff_type,typeof(autojacvec),typeof(linsolve)}(autojacvec,linsolve) -end -setvjp(sensealg::SteadyStateAdjoint{CS,AD,FDT,LS},vjp) where {CS,AD,FDT,LS} = - SteadyStateAdjoint{CS,AD,FDT,typeof(vjp),LS}(vjp,sensealg.linsolve) - -abstract type VJPChoice end - -""" -ZygoteVJP <: VJPChoice - -Uses Zygote.jl to compute vector-Jacobian products. Tends to be the fastest VJP method if the -ODE/DAE/SDE/DDE is written with mostly vectorized functions (like neural networks and other -layers from Flux.jl) and the `f` functions is given out-of-place. If the `f` function is -in-place, then `Zygote.Buffer` arrays are used internally which can greatly reduce the -performance of the VJP method. - -## Constructor - -```julia -ZygoteVJP(;allow_nothing=false) -``` - -Keyword arguments: - -* `allow_nothing`: whether `nothing`s should be implicitly converted to zeros. In Zygote, - the derivative of a function with respect to `p` which does not use `p` in any possible - calculation is given a derivative of `nothing` instead of zero. By default, this `nothing` - is caught in order to throw an informative error message about a potentially unintentional - misdefined function. However, if this was intentional, setting `allow_nothing=true` will - remove the error message. - -""" -struct ZygoteVJP <: VJPChoice - allow_nothing::Bool -end -ZygoteVJP(;allow_nothing=false) = ZygoteVJP(allow_nothing) - - -""" -EnzymeVJP <: VJPChoice - -Uses Enzyme.jl to compute vector-Jacobian products. Is the fastest VJP whenever applicable, -though Enzyme.jl currently has low coverage over the Julia programming language, for example -restricting the user's defined `f` function to not do things like require garbage collection -or calls to BLAS/LAPACK. However, mutation is supported, meaning that in-place `f` with -fully mutating non-allocating code will work with Enzyme (provided no high level calls to C -like BLAS/LAPACK are used) and this will be the most efficient adjoint implementation. - -## Constructor - -```julia -EnzymeVJP(compile=false) -``` -""" -struct EnzymeVJP <: VJPChoice end - -""" -TrackerVJP <: VJPChoice - -Uses Tracker.jl to compute the vector-Jacobian products. If `f` is in-place, -then it uses a array of structs formulation to do scalarized reverse mode, -while if `f` is out-of-place then it uses an array-based reverse mode. - -Not as efficient as `ReverseDiffVJP`, but supports GPUs when doing array-based -reverse mode. - -## Constructor - -```julia -TrackerVJP(;allow_nothing=false) -``` - -Keyword arguments: - -* `allow_nothing`: whether non-tracked values should be implicitly converted to zeros. In Tracker, - the derivative of a function with respect to `p` which does not use `p` in any possible - calculation is given an untracked return instead of zero. By default, this `nothing` Trackedness - is caught in order to throw an informative error message about a potentially unintentional - misdefined function. However, if this was intentional, setting `allow_nothing=true` will - remove the error message. -""" -struct TrackerVJP <: VJPChoice - allow_nothing::Bool -end -TrackerVJP(;allow_nothing=false) = TrackerVJP(allow_nothing) - -""" -ReverseDiffVJP{compile} <: VJPChoice - -Uses ReverseDiff.jl to compute the vector-Jacobian products. If `f` is in-place, -then it uses a array of structs formulation to do scalarized reverse mode, -while if `f` is out-of-place then it uses an array-based reverse mode. - -Usually the fastest when scalarized operations exist in the f function -(like in scientific machine learning applications like Universal Differential Equations) -and the boolean compilation is enabled (i.e. ReverseDiffVJP(true)), if EnzymeVJP fails on -a given choice of `f`. - -Does not support GPUs (CuArrays). - -## Constructor - -```julia -ReverseDiffVJP(compile=false) -``` - -## Keyword Arguments - -* `compile`: Whether to cache the compilation of the reverse tape. This heavily increases - the performance of the method but requires that the `f` function of the ODE/DAE/SDE/DDE - has no branching. -""" -struct ReverseDiffVJP{compile} <: VJPChoice - ReverseDiffVJP(compile=false) = new{compile}() -end - -@inline convert_tspan(::ForwardDiffSensitivity{CS,CTS}) where {CS,CTS} = CTS -@inline convert_tspan(::Any) = nothing -@inline alg_autodiff(alg::DiffEqBase.AbstractSensitivityAlgorithm{CS,AD,FDT}) where {CS,AD,FDT} = AD -@inline get_chunksize(alg::DiffEqBase.AbstractSensitivityAlgorithm{CS,AD,FDT}) where {CS,AD,FDT} = CS -@inline diff_type(alg::DiffEqBase.AbstractSensitivityAlgorithm{CS,AD,FDT}) where {CS,AD,FDT} = FDT -@inline function get_jacvec(alg::DiffEqBase.AbstractSensitivityAlgorithm) - alg.autojacvec isa Bool ? alg.autojacvec : true -end -@inline function get_jacmat(alg::DiffEqBase.AbstractSensitivityAlgorithm) - alg.autojacmat isa Bool ? alg.autojacmat : true -end -@inline ischeckpointing(alg::DiffEqBase.AbstractSensitivityAlgorithm, sol=nothing) = false -@inline ischeckpointing(alg::InterpolatingAdjoint) = alg.checkpointing -@inline ischeckpointing(alg::InterpolatingAdjoint, sol) = alg.checkpointing || !sol.dense -@inline ischeckpointing(alg::BacksolveAdjoint, sol=nothing) = alg.checkpointing - -@inline isnoisemixing(alg::DiffEqBase.AbstractSensitivityAlgorithm) = false -@inline isnoisemixing(alg::InterpolatingAdjoint) = alg.noisemixing -@inline isnoisemixing(alg::BacksolveAdjoint) = alg.noisemixing - -@inline compile_tape(vjp::ReverseDiffVJP{compile}) where compile = compile -@inline compile_tape(autojacvec::Bool) = false - -""" -ForwardDiffOverAdjoint{A} <: AbstractSecondOrderSensitivityAlgorithm{nothing,true,nothing} - -ForwardDiff.jl over a choice of `sensealg` method for the adjoint. - -## Constructor - -```julia -ForwardDiffOverAdjoint(sensealg) -``` - -## SciMLProblem Support - -This supports any SciMLProblem that the `sensealg` choice supports, provided the solver algorithm -is `SciMLBase.isautodifferentiable`. - -## References - -Hindmarsh, A. C. and Brown, P. N. and Grant, K. E. and Lee, S. L. and Serban, R. -and Shumaker, D. E. and Woodward, C. S., SUNDIALS: Suite of nonlinear and -differential/algebraic equation solvers, ACM Transactions on Mathematical -Software (TOMS), 31, pp:363–396 (2005) -""" -struct ForwardDiffOverAdjoint{A} <: AbstractSecondOrderSensitivityAlgorithm{nothing,true,nothing} - adjalg::A -end +function SensitivityAlg(args...; kwargs...) + @error("The SensitivtyAlg choice mechanism was completely overhauled. Please consult the local sensitivity documentation for more information") +end + +abstract type AbstractForwardSensitivityAlgorithm{CS, AD, FDT} <: + DiffEqBase.AbstractSensitivityAlgorithm{CS, AD, FDT} end +abstract type AbstractAdjointSensitivityAlgorithm{CS, AD, FDT} <: + DiffEqBase.AbstractSensitivityAlgorithm{CS, AD, FDT} end +abstract type AbstractSecondOrderSensitivityAlgorithm{CS, AD, FDT} <: + DiffEqBase.AbstractSensitivityAlgorithm{CS, AD, FDT} end +abstract type AbstractShadowingSensitivityAlgorithm{CS, AD, FDT} <: + DiffEqBase.AbstractSensitivityAlgorithm{CS, AD, FDT} end + +""" +ForwardSensitivity{CS,AD,FDT} <: AbstractForwardSensitivityAlgorithm{CS,AD,FDT} + +An implementation of continuous forward sensitivity analysis for propagating +derivatives by solving the extended ODE. When used within adjoint differentiation +(i.e. via Zygote), this will cause forward differentiation of the `solve` call +within the reverse-mode automatic differentiation environment. + +## Constructor + +```julia +function ForwardSensitivity(; + chunk_size=0,autodiff=true, + diff_type=Val{:central}, + autojacvec=autodiff, + autojacmat=false) +``` + +## Keyword Arguments + +* `autodiff`: Use automatic differentiation in the internal sensitivity algorithm + computations. Default is `true`. +* `chunk_size`: Chunk size for forward mode differentiation if full Jacobians are + built (`autojacvec=false` and `autodiff=true`). Default is `0` for automatic + choice of chunk size. +* `autojacvec`: Calculate the Jacobian-vector product via automatic + differentiation with special seeding. +* `diff_type`: The method used by FiniteDiff.jl for constructing the Jacobian + if the full Jacobian is required with `autodiff=false`. + +Further details: + +- If `autodiff=true` and `autojacvec=true`, then the one chunk `J*v` forward-mode + directional derivative calculation trick is used to compute the product without + constructing the Jacobian (via ForwardDiff.jl). +- If `autodiff=false` and `autojacvec=true`, then the numerical direction derivative + trick `(f(x+epsilon*v)-f(x))/epsilon` is used to compute `J*v` without constructing + the Jacobian. +- If `autodiff=true` and `autojacvec=false`, then the Jacobian is constructed via + chunked forward-mode automatic differentiation (via ForwardDiff.jl). +- If `autodiff=false` and `autojacvec=false`, then the Jacobian is constructed via + finite differences via FiniteDiff.jl. + +## SciMLProblem Support + +This `sensealg` only supports `ODEProblem`s without callbacks (events). +""" +struct ForwardSensitivity{CS, AD, FDT} <: AbstractForwardSensitivityAlgorithm{CS, AD, FDT} + autojacvec::Bool + autojacmat::Bool +end +Base.@pure function ForwardSensitivity(; + chunk_size = 0, autodiff = true, + diff_type = Val{:central}, + autojacvec = autodiff, + autojacmat = false) + autojacvec && autojacmat && + error("Choose either Jacobian matrix products or Jacobian vector products, + autojacmat and autojacvec cannot both be true") + ForwardSensitivity{chunk_size, autodiff, diff_type}(autojacvec, autojacmat) +end + +""" +ForwardDiffSensitivity{CS,CTS} <: AbstractForwardSensitivityAlgorithm{CS,Nothing,Nothing} + +An implementation of discrete forward sensitivity analysis through ForwardDiff.jl. +When used within adjoint differentiation (i.e. via Zygote), this will cause forward +differentiation of the `solve` call within the reverse-mode automatic differentiation +environment. + +## Constructor + +```julia +ForwardDiffSensitivity(;chunk_size=0,convert_tspan=nothing) +``` + +## Keyword Arguments + +* `chunk_size`: the chunk size used by ForwardDiff for computing the Jacobian, i.e. the + number of simultaneous columns computed. +* `convert_tspan`: whether to convert time to also be `Dual` valued. By default this is + `nothing` which will only convert if callbacks are found. Conversion is required in order + to accurately differentiate callbacks (hybrid equations). + +## SciMLProblem Support + +This `sensealg` supports any `SciMLProblem`s, provided that the solver algorithms is +`SciMLBase.isautodifferentiable`. Note that `ForwardDiffSensitivity` can +accurately differentiate code with callbacks only when `convert_tspan=true`. +""" +struct ForwardDiffSensitivity{CS, CTS} <: + AbstractForwardSensitivityAlgorithm{CS, Nothing, Nothing} end +Base.@pure function ForwardDiffSensitivity(; chunk_size = 0, convert_tspan = nothing) + ForwardDiffSensitivity{chunk_size, convert_tspan}() +end + +""" +BacksolveAdjoint{CS,AD,FDT,VJP} <: AbstractAdjointSensitivityAlgorithm{CS,AD,FDT} + +An implementation of adjoint sensitivity analysis using a backwards solution of the ODE. +By default this algorithm will use the values from the forward pass to perturb the +backwards solution to the correct spot, allowing reduced memory (O(1) memory). Checkpointing +stabilization is included for additional numerical stability over the naive implementation. + +## Constructor + +```julia +BacksolveAdjoint(;chunk_size=0,autodiff=true, + diff_type=Val{:central}, + autojacvec=nothing, + checkpointing=true, noisemixing=false) +``` + +## Keyword Arguments + +* `autodiff`: Use automatic differentiation for constructing the Jacobian + if the Jacobian needs to be constructed. Defaults to `true`. +* `chunk_size`: Chunk size for forward-mode differentiation if full Jacobians are + built (`autojacvec=false` and `autodiff=true`). Default is `0` for automatic + choice of chunk size. +* `diff_type`: The method used by FiniteDiff.jl for constructing the Jacobian + if the full Jacobian is required with `autodiff=false`. +* `autojacvec`: Calculate the vector-Jacobian product (`J'*v`) via automatic + differentiation with special seeding. The default is `true`. The total set + of choices are: + - `false`: the Jacobian is constructed via FiniteDiff.jl + - `true`: the Jacobian is constructed via ForwardDiff.jl + - `TrackerVJP`: Uses Tracker.jl for the vjp. + - `ZygoteVJP`: Uses Zygote.jl for the vjp. + - `EnzymeVJP`: Uses Enzyme.jl for the vjp. + - `ReverseDiffVJP(compile=false)`: Uses ReverseDiff.jl for the vjp. `compile` + is a boolean for whether to precompile the tape, which should only be done + if there are no branches (`if` or `while` statements) in the `f` function. +* `checkpointing`: whether checkpointing is enabled for the reverse pass. Defaults + to `true`. +* `noisemixing`: Handle noise processes that are not of the form `du[i] = f(u[i])`. + For example, to compute the sensitivities of an SDE with diagonal diffusion + ```julia + function g_mixing!(du,u,p,t) + du[1] = p[3]*u[1] + p[4]*u[2] + du[2] = p[3]*u[1] + p[4]*u[2] + nothing + end + ``` + correctly, `noisemixing=true` must be enabled. The default is `false`. + +For more details on the vjp choices, please consult the sensitivity algorithms +documentation page or the docstrings of the vjp types. + +## Applicability of Backsolve and Caution + +When `BacksolveAdjoint` is applicable, it is a fast method and requires the least memory. +However, one must be cautious because not all ODEs are stable under backwards integration +by the majority of ODE solvers. An example of such an equation is the Lorenz equation. +Notice that if one solves the Lorenz equation forward and then in reverse with any +adaptive time step and non-reversible integrator, then the backwards solution diverges +from the forward solution. As a quick demonstration: + +```julia +using Sundials +function lorenz(du,u,p,t) + du[1] = 10.0*(u[2]-u[1]) + du[2] = u[1]*(28.0-u[3]) - u[2] + du[3] = u[1]*u[2] - (8/3)*u[3] +end +u0 = [1.0;0.0;0.0] +tspan = (0.0,100.0) +prob = ODEProblem(lorenz,u0,tspan) +sol = solve(prob,Tsit5(),reltol=1e-12,abstol=1e-12) +prob2 = ODEProblem(lorenz,sol[end],(100.0,0.0)) +sol = solve(prob,Tsit5(),reltol=1e-12,abstol=1e-12) +@show sol[end]-u0 #[-3.22091, -1.49394, 21.3435] +``` + +Thus one should check the stability of the backsolve on their type of problem before +enabling this method. Additionally, using checkpointing with backsolve can be a +low memory way to stabilize it. + +For more details on this topic, see +[Stiff Neural Ordinary Differential Equations](https://aip.scitation.org/doi/10.1063/5.0060697). + +## Checkpointing + +To improve the numerical stability of the reverse pass, `BacksolveAdjoint` includes a checkpointing +feature. If `sol.u` is a time series, then whenever a time `sol.t` is hit while reversing, a callback +will replace the reversing ODE portion with `sol.u[i]`. This nudges the solution back onto the appropriate +trajectory and reduces the numerical caused by drift. + +## SciMLProblem Support + +This `sensealg` only supports `ODEProblem`s, `SDEProblem`s, and `RODEProblem`s. This `sensealg` supports +callback functions (events). + +## References + +ODE: + Rackauckas, C. and Ma, Y. and Martensen, J. and Warner, C. and Zubov, K. and Supekar, + R. and Skinner, D. and Ramadhana, A. and Edelman, A., Universal Differential Equations + for Scientific Machine Learning, arXiv:2001.04385 + + Hindmarsh, A. C. and Brown, P. N. and Grant, K. E. and Lee, S. L. and Serban, R. + and Shumaker, D. E. and Woodward, C. S., SUNDIALS: Suite of nonlinear and + differential/algebraic equation solvers, ACM Transactions on Mathematical + Software (TOMS), 31, pp:363–396 (2005) + + Chen, R.T.Q. and Rubanova, Y. and Bettencourt, J. and Duvenaud, D. K., + Neural ordinary differential equations. In Advances in neural information processing + systems, pp. 6571–6583 (2018) + + Pontryagin, L. S. and Mishchenko, E.F. and Boltyanskii, V.G. and Gamkrelidze, R.V. + The mathematical theory of optimal processes. Routledge, (1962) + + Rackauckas, C. and Ma, Y. and Dixit, V. and Guo, X. and Innes, M. and Revels, J. + and Nyberg, J. and Ivaturi, V., A comparison of automatic differentiation and + continuous sensitivity analysis for derivatives of differential equation solutions, + arXiv:1812.01892 + +DAE: + Cao, Y. and Li, S. and Petzold, L. and Serban, R., Adjoint sensitivity analysis + for differential-algebraic equations: The adjoint DAE system and its numerical + solution, SIAM journal on scientific computing 24 pp: 1076-1089 (2003) + +SDE: + Gobet, E. and Munos, R., Sensitivity Analysis Using Ito-Malliavin Calculus and + Martingales, and Application to Stochastic Optimal Control, + SIAM Journal on control and optimization, 43, pp. 1676-1713 (2005) + + Li, X. and Wong, T.-K. L.and Chen, R. T. Q. and Duvenaud, D., + Scalable Gradients for Stochastic Differential Equations, + PMLR 108, pp. 3870-3882 (2020), http://proceedings.mlr.press/v108/li20i.html +""" +struct BacksolveAdjoint{CS, AD, FDT, VJP} <: + AbstractAdjointSensitivityAlgorithm{CS, AD, FDT} + autojacvec::VJP + checkpointing::Bool + noisemixing::Bool +end +Base.@pure function BacksolveAdjoint(; chunk_size = 0, autodiff = true, + diff_type = Val{:central}, + autojacvec = nothing, + checkpointing = true, noisemixing = false) + BacksolveAdjoint{chunk_size, autodiff, diff_type, typeof(autojacvec)}(autojacvec, + checkpointing, + noisemixing) +end +function setvjp(sensealg::BacksolveAdjoint{CS, AD, FDT, Nothing}, vjp) where {CS, AD, FDT} + BacksolveAdjoint{CS, AD, FDT, typeof(vjp)}(vjp, sensealg.checkpointing, + sensealg.noisemixing) +end + +""" +InterpolatingAdjoint{CS,AD,FDT,VJP} <: AbstractAdjointSensitivityAlgorithm{CS,AD,FDT} + +An implementation of adjoint sensitivity analysis which uses the interpolation of +the forward solution for the reverse solve vector-Jacobian products. By +default it requires a dense solution of the forward pass and will internally +ignore saving arguments during the gradient calculation. When checkpointing is +enabled it will only require the memory to interpolate between checkpoints. + +## Constructor + +```julia +function InterpolatingAdjoint(;chunk_size=0,autodiff=true, + diff_type=Val{:central}, + autojacvec=nothing, + checkpointing=false, noisemixing=false) +``` + +## Keyword Arguments + +* `autodiff`: Use automatic differentiation for constructing the Jacobian + if the Jacobian needs to be constructed. Defaults to `true`. +* `chunk_size`: Chunk size for forward-mode differentiation if full Jacobians are + built (`autojacvec=false` and `autodiff=true`). Default is `0` for automatic + choice of chunk size. +* `diff_type`: The method used by FiniteDiff.jl for constructing the Jacobian + if the full Jacobian is required with `autodiff=false`. +* `autojacvec`: Calculate the vector-Jacobian product (`J'*v`) via automatic + differentiation with special seeding. The default is `true`. The total set + of choices are: + - `false`: the Jacobian is constructed via FiniteDiff.jl + - `true`: the Jacobian is constructed via ForwardDiff.jl + - `TrackerVJP`: Uses Tracker.jl for the vjp. + - `ZygoteVJP`: Uses Zygote.jl for the vjp. + - `EnzymeVJP`: Uses Enzyme.jl for the vjp. + - `ReverseDiffVJP(compile=false)`: Uses ReverseDiff.jl for the vjp. `compile` + is a boolean for whether to precompile the tape, which should only be done + if there are no branches (`if` or `while` statements) in the `f` function. +* `checkpointing`: whether checkpointing is enabled for the reverse pass. Defaults + to `true`. +* `noisemixing`: Handle noise processes that are not of the form `du[i] = f(u[i])`. + For example, to compute the sensitivities of an SDE with diagonal diffusion + ```julia + function g_mixing!(du,u,p,t) + du[1] = p[3]*u[1] + p[4]*u[2] + du[2] = p[3]*u[1] + p[4]*u[2] + nothing + end + ``` + correctly, `noisemixing=true` must be enabled. The default is `false`. + +For more details on the vjp choices, please consult the sensitivity algorithms +documentation page or the docstrings of the vjp types. + +## Checkpointing + +To reduce the memory usage of the reverse pass, `InterpolatingAdjoint` includes a checkpointing +feature. If `sol` is `dense`, checkpointing is ignored and the continuous solution is used for +calculating `u(t)` at arbitrary time points. If `checkpointing=true` and `sol` is not `dense`, +then dense intervals between `sol.t[i]` and `sol.t[i+1]` are reconstructed on-demand for calculating +`u(t)` at arbitrary time points. This reduces the total memory requirement to only the cost of +holding the dense solution over the largest time interval (in terms of number of required steps). +The total compute cost is no more than double the original forward compute cost. + +## SciMLProblem Support + +This `sensealg` only supports `ODEProblem`s, `SDEProblem`s, and `RODEProblem`s. This `sensealg` +supports callbacks (events). + +## References + + Rackauckas, C. and Ma, Y. and Martensen, J. and Warner, C. and Zubov, K. and Supekar, + R. and Skinner, D. and Ramadhana, A. and Edelman, A., Universal Differential Equations + for Scientific Machine Learning, arXiv:2001.04385 + + Hindmarsh, A. C. and Brown, P. N. and Grant, K. E. and Lee, S. L. and Serban, R. + and Shumaker, D. E. and Woodward, C. S., SUNDIALS: Suite of nonlinear and + differential/algebraic equation solvers, ACM Transactions on Mathematical + Software (TOMS), 31, pp:363–396 (2005) + + Rackauckas, C. and Ma, Y. and Dixit, V. and Guo, X. and Innes, M. and Revels, J. + and Nyberg, J. and Ivaturi, V., A comparison of automatic differentiation and + continuous sensitivity analysis for derivatives of differential equation solutions, + arXiv:1812.01892 +""" +struct InterpolatingAdjoint{CS, AD, FDT, VJP} <: + AbstractAdjointSensitivityAlgorithm{CS, AD, FDT} + autojacvec::VJP + checkpointing::Bool + noisemixing::Bool +end +Base.@pure function InterpolatingAdjoint(; chunk_size = 0, autodiff = true, + diff_type = Val{:central}, + autojacvec = nothing, + checkpointing = false, noisemixing = false) + InterpolatingAdjoint{chunk_size, autodiff, diff_type, typeof(autojacvec)}(autojacvec, + checkpointing, + noisemixing) +end +function setvjp(sensealg::InterpolatingAdjoint{CS, AD, FDT, Nothing}, + vjp) where {CS, AD, FDT} + InterpolatingAdjoint{CS, AD, FDT, typeof(vjp)}(vjp, sensealg.checkpointing, + sensealg.noisemixing) +end + +""" +QuadratureAdjoint{CS,AD,FDT,VJP} <: AbstractAdjointSensitivityAlgorithm{CS,AD,FDT} + +An implementation of adjoint sensitivity analysis which develops a full +continuous solution of the reverse solve in order to perform a post-ODE +quadrature. This method requires the the dense solution and will ignore +saving arguments during the gradient calculation. The tolerances in the +constructor control the inner quadrature. The inner quadrature uses a +ReverseDiff vjp if autojacvec, and `compile=false` by default but can +compile the tape under the same circumstances as `ReverseDiffVJP`. + +This method is O(n^3 + p) for stiff / implicit equations (as opposed to the +O((n+p)^3) scaling of BacksolveAdjoint and InterpolatingAdjoint), and thus +is much more compute efficient. However, it requires holding a dense reverse +pass and is thus memory intensive. + +## Constructor + +```julia +function QuadratureAdjoint(;chunk_size=0,autodiff=true, + diff_type=Val{:central}, + autojacvec=nothing,abstol=1e-6, + reltol=1e-3,compile=false) +``` + +## Keyword Arguments + +* `autodiff`: Use automatic differentiation for constructing the Jacobian + if the Jacobian needs to be constructed. Defaults to `true`. +* `chunk_size`: Chunk size for forward-mode differentiation if full Jacobians are + built (`autojacvec=false` and `autodiff=true`). Default is `0` for automatic + choice of chunk size. +* `diff_type`: The method used by FiniteDiff.jl for constructing the Jacobian + if the full Jacobian is required with `autodiff=false`. +* `autojacvec`: Calculate the vector-Jacobian product (`J'*v`) via automatic + differentiation with special seeding. The default is `true`. The total set + of choices are: + - `false`: the Jacobian is constructed via FiniteDiff.jl + - `true`: the Jacobian is constructed via ForwardDiff.jl + - `TrackerVJP`: Uses Tracker.jl for the vjp. + - `ZygoteVJP`: Uses Zygote.jl for the vjp. + - `EnzymeVJP`: Uses Enzyme.jl for the vjp. + - `ReverseDiffVJP(compile=false)`: Uses ReverseDiff.jl for the vjp. `compile` + is a boolean for whether to precompile the tape, which should only be done + if there are no branches (`if` or `while` statements) in the `f` function. +* `abstol`: absolute tolerance for the quadrature calculation +* `reltol`: relative tolerance for the quadrature calculation +* `compile`: whether to compile the vjp calculation for the integrand calculation. + See `ReverseDiffVJP` for more details. + +For more details on the vjp choices, please consult the sensitivity algorithms +documentation page or the docstrings of the vjp types. + +## SciMLProblem Support + +This `sensealg` only supports `ODEProblem`s. This `sensealg` supports events (callbacks). + +## References + + Rackauckas, C. and Ma, Y. and Martensen, J. and Warner, C. and Zubov, K. and Supekar, + R. and Skinner, D. and Ramadhana, A. and Edelman, A., Universal Differential Equations + for Scientific Machine Learning, arXiv:2001.04385 + + Hindmarsh, A. C. and Brown, P. N. and Grant, K. E. and Lee, S. L. and Serban, R. + and Shumaker, D. E. and Woodward, C. S., SUNDIALS: Suite of nonlinear and + differential/algebraic equation solvers, ACM Transactions on Mathematical + Software (TOMS), 31, pp:363–396 (2005) + + Rackauckas, C. and Ma, Y. and Dixit, V. and Guo, X. and Innes, M. and Revels, J. + and Nyberg, J. and Ivaturi, V., A comparison of automatic differentiation and + continuous sensitivity analysis for derivatives of differential equation solutions, + arXiv:1812.01892 + + Kim, S., Ji, W., Deng, S., Ma, Y., & Rackauckas, C. (2021). Stiff neural ordinary + differential equations. Chaos: An Interdisciplinary Journal of Nonlinear Science, 31(9), 093122. +""" +struct QuadratureAdjoint{CS, AD, FDT, VJP} <: + AbstractAdjointSensitivityAlgorithm{CS, AD, FDT} + autojacvec::VJP + abstol::Float64 + reltol::Float64 +end +Base.@pure function QuadratureAdjoint(; chunk_size = 0, autodiff = true, + diff_type = Val{:central}, + autojacvec = nothing, abstol = 1e-6, + reltol = 1e-3) + QuadratureAdjoint{chunk_size, autodiff, diff_type, typeof(autojacvec)}(autojacvec, + abstol, reltol) +end +function setvjp(sensealg::QuadratureAdjoint{CS, AD, FDT, Nothing}, vjp) where {CS, AD, FDT} + QuadratureAdjoint{CS, AD, FDT, typeof(vjp)}(vjp, sensealg.abstol, + sensealg.reltol) +end + +""" +TrackerAdjoint <: AbstractAdjointSensitivityAlgorithm{nothing,true,nothing} + +An implementation of discrete adjoint sensitivity analysis +using the Tracker.jl tracing-based AD. Supports in-place functions through +an Array of Structs formulation, and supports out of place through struct of +arrays. + +## Constructor + +```julia +TrackerAdjoint() +``` + +## SciMLProblem Support + +This `sensealg` supports any `DEProblem` if the algorithm is `SciMLBase.isautodifferentiable` +Compatible with a limited subset of `AbstractArray` types for `u0`, including `CuArrays`. + +!!! warn + + TrackerAdjoint is incompatible with Stiff ODE solvers using forward-mode automatic + differentiation for the Jacobians. Thus for example, `TRBDF2()` will error. Instead, + use `autodiff=false`, i.e. `TRBDF2(autodiff=false)`. This will only remove the + forward-mode automatic differentiation of the Jacobian construction, not the reverse-mode + AD usage, and thus performance will still be nearly the same, though Jacobian accuracy + may suffer which could cause more steps to be required. +""" +struct TrackerAdjoint <: AbstractAdjointSensitivityAlgorithm{nothing, true, nothing} end + +""" +ReverseDiffAdjoint <: AbstractAdjointSensitivityAlgorithm{nothing,true,nothing} + +An implementation of discrete adjoint sensitivity analysis using the ReverseDiff.jl +tracing-based AD. Supports in-place functions through an Array of Structs formulation, +and supports out of place through struct of arrays. + +## Constructor + +```julia +ReverseDiffAdjoint() +``` + +## SciMLProblem Support + +This `sensealg` supports any `DEProblem` if the algorithm is `SciMLBase.isautodifferentiable`. +Requires that the state variables are CPU-based `Array` types. +""" +struct ReverseDiffAdjoint <: AbstractAdjointSensitivityAlgorithm{nothing, true, nothing} end + +""" +ZygoteAdjoint <: AbstractAdjointSensitivityAlgorithm{nothing,true,nothing} + +An implementation of discrete adjoint sensitivity analysis +using the Zygote.jl source-to-source AD directly on the differential equation +solver. + +## Constructor + +```julia +ZygoteAdjoint() +``` + +## SciMLProblem Support + +Currently fails on almost every solver. +""" +struct ZygoteAdjoint <: AbstractAdjointSensitivityAlgorithm{nothing, true, nothing} end + +""" +ForwardLSS{CS,AD,FDT,RType,gType} <: AbstractShadowingSensitivityAlgorithm{CS,AD,FDT} + +An implementation of the discrete, forward-mode +[least squares shadowing](https://arxiv.org/abs/1204.0159) (LSS) method. LSS replaces +the ill-conditioned initial value probem (`ODEProblem`) for chaotic systems by a +well-conditioned least-squares problem. This allows for computing sensitivities of +long-time averaged quantities with respect to the parameters of the `ODEProblem`. The +computational cost of LSS scales as (number of states x number of time steps). Converges +to the correct sensitivity at a rate of `T^(-1/2)`, where `T` is the time of the trajectory. +See `NILSS()` and `NILSAS()` for a more efficient non-intrusive formulation. + +## Constructor + +```julia +ForwardLSS(; + chunk_size=0,autodiff=true, + diff_type=Val{:central}, + LSSregularizer=TimeDilation(10.0,0.0,0.0), + g=nothing) +``` + +## Keyword Arguments + +* `autodiff`: Use automatic differentiation for constructing the Jacobian + if the Jacobian needs to be constructed. Defaults to `true`. +* `chunk_size`: Chunk size for forward-mode differentiation if full Jacobians are + built (`autojacvec=false` and `autodiff=true`). Default is `0` for automatic + choice of chunk size. +* `diff_type`: The method used by FiniteDiff.jl for constructing the Jacobian + if the full Jacobian is required with `autodiff=false`. +* `LSSregularizer`: Using `LSSregularizer`, one can choose between three different + regularization routines. The default choice is `TimeDilation(10.0,0.0,0.0)`. + - `CosWindowing()`: cos windowing of the time grid, i.e. the time grid (saved + time steps) is transformed using a cosine. + - `Cos2Windowing()`: cos^2 windowing of the time grid. + - `TimeDilation(alpha::Number,t0skip::Number,t1skip::Number)`: Corresponds to + a time dilation. `alpha` controls the weight. `t0skip` and `t1skip` indicate + the times truncated at the beginnning and end of the trajectory, respectively. +* `g`: instantaneous objective function of the long-time averaged objective. + +## SciMLProblem Support + +This `sensealg` only supports `ODEProblem`s. This `sensealg` does not support +events (callbacks). This `sensealg` assumes that the objective is a long-time averaged +quantity and ergodic, i.e. the time evolution of the system behaves qualitatively the +same over infinite time independent of the specified initial conditions, such that only +the sensitivity with respect to the parameters is of interest. + +## References + +Wang, Q., Hu, R., and Blonigan, P. Least squares shadowing sensitivity analysis of +chaotic limit cycle oscillations. Journal of Computational Physics, 267, 210-224 (2014). + +Wang, Q., Convergence of the Least Squares Shadowing Method for Computing Derivative of Ergodic +Averages, SIAM Journal on Numerical Analysis, 52, 156–170 (2014). + +Blonigan, P., Gomez, S., Wang, Q., Least Squares Shadowing for sensitivity analysis of turbulent +fluid flows, in: 52nd Aerospace Sciences Meeting, 1–24 (2014). +""" +struct ForwardLSS{CS, AD, FDT, RType, gType} <: + AbstractShadowingSensitivityAlgorithm{CS, AD, FDT} + LSSregularizer::RType + g::gType +end +Base.@pure function ForwardLSS(; + chunk_size = 0, autodiff = true, + diff_type = Val{:central}, + LSSregularizer = TimeDilation(10.0, 0.0, 0.0), + g = nothing) + ForwardLSS{chunk_size, autodiff, diff_type, typeof(LSSregularizer), typeof(g)}(LSSregularizer, + g) +end + +""" +AdjointLSS{CS,AD,FDT,RType,gType} <: AbstractShadowingSensitivityAlgorithm{CS,AD,FDT} + +An implementation of the discrete, adjoint-mode +[least square shadowing](https://arxiv.org/abs/1204.0159) method. LSS replaces +the ill-conditioned initial value probem (`ODEProblem`) for chaotic systems by a +well-conditioned least-squares problem. This allows for computing sensitivities of +long-time averaged quantities with respect to the parameters of the `ODEProblem`. The +computational cost of LSS scales as (number of states x number of time steps). Converges +to the correct sensitivity at a rate of `T^(-1/2)`, where `T` is the time of the trajectory. +See `NILSS()` and `NILSAS()` for a more efficient non-intrusive formulation. + +## Constructor + +```julia +AdjointLSS(; + chunk_size=0,autodiff=true, + diff_type=Val{:central}, + LSSRegularizer=TimeDilation(10.0,0.0,0.0), + g=nothing) +``` + +## Keyword Arguments + +* `autodiff`: Use automatic differentiation for constructing the Jacobian + if the Jacobian needs to be constructed. Defaults to `true`. +* `chunk_size`: Chunk size for forward-mode differentiation if full Jacobians are + built (`autojacvec=false` and `autodiff=true`). Default is `0` for automatic + choice of chunk size. +* `diff_type`: The method used by FiniteDiff.jl for constructing the Jacobian + if the full Jacobian is required with `autodiff=false`. +* `LSSregularizer`: Using `LSSregularizer`, one can choose between different + regularization routines. The default choice is `TimeDilation(10.0,0.0,0.0)`. + - `TimeDilation(alpha::Number,t0skip::Number,t1skip::Number)`: Corresponds to + a time dilation. `alpha` controls the weight. `t0skip` and `t1skip` indicate + the times truncated at the beginnning and end of the trajectory, respectively. + The default value for `t0skip` and `t1skip` is `zero(alpha)`. +* `g`: instantaneous objective function of the long-time averaged objective. + +## SciMLProblem Support + +This `sensealg` only supports `ODEProblem`s. This `sensealg` does not support +events (callbacks). This `sensealg` assumes that the objective is a long-time averaged +quantity and ergodic, i.e. the time evolution of the system behaves qualitatively the +same over infinite time independent of the specified initial conditions, such that only +the sensitivity with respect to the parameters is of interest. + +## References + +Wang, Q., Hu, R., and Blonigan, P. Least squares shadowing sensitivity analysis of +chaotic limit cycle oscillations. Journal of Computational Physics, 267, 210-224 (2014). +""" +struct AdjointLSS{CS, AD, FDT, RType, gType} <: + AbstractShadowingSensitivityAlgorithm{CS, AD, FDT} + LSSregularizer::RType + g::gType +end +Base.@pure function AdjointLSS(; + chunk_size = 0, autodiff = true, + diff_type = Val{:central}, + LSSregularizer = TimeDilation(10.0, 0.0, 0.0), + g = nothing) + AdjointLSS{chunk_size, autodiff, diff_type, typeof(LSSregularizer), typeof(g)}(LSSregularizer, + g) +end + +abstract type AbstractLSSregularizer end +abstract type AbstractCosWindowing <: AbstractLSSregularizer end +struct CosWindowing <: AbstractCosWindowing end +struct Cos2Windowing <: AbstractCosWindowing end + +""" +TimeDilation{T1<:Number} <: AbstractLSSregularizer + +A regularization method for `LSS`. See `?LSS` for +additional information and other methods. + +## Constructor + +```julia +TimeDilation(alpha; + t0skip=zero(alpha), + t1skip=zero(alpha)) +``` +""" +struct TimeDilation{T1 <: Number} <: AbstractLSSregularizer + alpha::T1 # alpha: weight of the time dilation term in LSS. + t0skip::T1 + t1skip::T1 +end +function TimeDilation(alpha, t0skip = zero(alpha), t1skip = zero(alpha)) + TimeDilation{typeof(alpha)}(alpha, t0skip, t1skip) +end +""" +struct NILSS{CS,AD,FDT,RNG,nType,gType} <: AbstractShadowingSensitivityAlgorithm{CS,AD,FDT} + +An implementation of the forward-mode, continuous +[non-intrusive least squares shadowing](https://arxiv.org/abs/1611.00880) method. `NILSS` +allows for computing sensitivities of long-time averaged quantities with respect to the +parameters of an `ODEProblem` by constraining the computation to the unstable subspace. +`NILSS` employs the continuous-time `ForwardSensitivity` method as tangent solver. To +avoid an exponential blow-up of the (homogenous and inhomogenous) tangent solutions, +the trajectory should be divided into sufficiently small segments, where the tangent solutions +are rescaled on the interfaces. The computational and memory cost of NILSS scale with +the number of unstable (positive) Lyapunov exponents (instead of the number of states as +in the LSS method). `NILSS` avoids the explicit construction of the Jacobian at each time +step and thus should generally be preferred (for large system sizes) over `ForwardLSS`. + +## Constructor + +```julia +NILSS(nseg, nstep; nus = nothing, + rng = Xorshifts.Xoroshiro128Plus(rand(UInt64)), + chunk_size=0,autodiff=true, + diff_type=Val{:central}, + autojacvec=autodiff, + g=nothing) +``` + +## Arguments + +* `nseg`: Number of segments on full time interval on the attractor. +* `nstep`: number of steps on each segment. + +## Keyword Arguments + +* `nus`: Dimension of the unstable subspace. Default is `nothing`. `nus` must be + smaller or equal to the state dimension (`length(u0)`). With the default choice, + `nus = length(u0) - 1` will be set at compile time. +* `rng`: (Pseudo) random number generator. Used for initializing the homogenous + tangent states (`w`). Default is `Xorshifts.Xoroshiro128Plus(rand(UInt64))`. +* `autodiff`: Use automatic differentiation in the internal sensitivity algorithm + computations. Default is `true`. +* `chunk_size`: Chunk size for forward mode differentiation if full Jacobians are + built (`autojacvec=false` and `autodiff=true`). Default is `0` for automatic + choice of chunk size. +* `autojacvec`: Calculate the Jacobian-vector product via automatic + differentiation with special seeding. +* `diff_type`: The method used by FiniteDiff.jl for constructing the Jacobian + if the full Jacobian is required with `autodiff=false`. +* `g`: instantaneous objective function of the long-time averaged objective. + +## SciMLProblem Support + +This `sensealg` only supports `ODEProblem`s. This `sensealg` does not support +events (callbacks). This `sensealg` assumes that the objective is a long-time averaged +quantity and ergodic, i.e. the time evolution of the system behaves qualitatively the +same over infinite time independent of the specified initial conditions, such that only +the sensitivity with respect to the parameters is of interest. + +## References +Ni, A., Blonigan, P. J., Chater, M., Wang, Q., Zhang, Z., Sensitivity analy- +sis on chaotic dynamical system by Non-Intrusive Least Square Shadowing +(NI-LSS), in: 46th AIAA Fluid Dynamics Conference, AIAA AVIATION Forum (AIAA 2016-4399), +American Institute of Aeronautics and Astronautics, 1–16 (2016). + +Ni, A., and Wang, Q. Sensitivity analysis on chaotic dynamical systems by Non-Intrusive +Least Squares Shadowing (NILSS). Journal of Computational Physics 347, 56-77 (2017). +""" +struct NILSS{CS, AD, FDT, RNG, nType, gType} <: + AbstractShadowingSensitivityAlgorithm{CS, AD, FDT} + rng::RNG + nseg::Int + nstep::Int + nus::nType + autojacvec::Bool + g::gType +end +Base.@pure function NILSS(nseg, nstep; nus = nothing, + rng = Xorshifts.Xoroshiro128Plus(rand(UInt64)), + chunk_size = 0, autodiff = true, + diff_type = Val{:central}, + autojacvec = autodiff, + g = nothing) + NILSS{chunk_size, autodiff, diff_type, typeof(rng), typeof(nus), typeof(g)}(rng, nseg, + nstep, nus, + autojacvec, + g) +end + +""" +NILSAS{CS,AD,FDT,RNG,SENSE,gType} <: AbstractShadowingSensitivityAlgorithm{CS,AD,FDT} + +An implementation of the adjoint-mode, continuous +[non-intrusive adjoint least squares shadowing](https://arxiv.org/abs/1801.08674) method. +`NILSAS` allows for computing sensitivities of long-time averaged quantities with respect +to the parameters of an `ODEProblem` by constraining the computation to the unstable subspace. +`NILSAS` employs SciMLSensitivity.jl's continuous adjoint sensitivity methods on each segment +to compute (homogenous and inhomogenous) adjoint solutions. To avoid an exponential blow-up +of the adjoint solutions, the trajectory should be divided into sufficiently small segments, +where the adjoint solutions are rescaled on the interfaces. The computational and memory cost +of NILSAS scale with the number of unstable, adjoint Lyapunov exponents (instead of the number +of states as in the LSS method). `NILSAS` avoids the explicit construction of the Jacobian at +each time step and thus should generally be preferred (for large system sizes) over `AdjointLSS`. +`NILSAS` is favourable over `NILSS` for many parameters because NILSAS computes the gradient +with respect to multiple parameters with negligible additional cost. + +## Constructor + +```julia +NILSAS(nseg, nstep, M=nothing; rng = Xorshifts.Xoroshiro128Plus(rand(UInt64)), + adjoint_sensealg = BacksolveAdjoint(autojacvec=ReverseDiffVJP()), + chunk_size=0,autodiff=true, + diff_type=Val{:central}, + g=nothing + ) +``` + +## Arguments + +* `nseg`: Number of segments on full time interval on the attractor. +* `nstep`: number of steps on each segment. +* `M`: number of homogenous adjoint solutions. This number must be bigger or equal + than the number of (positive, adjoint) Lyapunov exponents. Default is `nothing`. + +## Keyword Arguments + +* `rng`: (Pseudo) random number generator. Used for initializing the terminate + conditions of the homogenous adjoint states (`w`). Default is `Xorshifts.Xoroshiro128Plus(rand(UInt64))`. +* `adjoint_sensealg`: Continuous adjoint sensitivity method to compute homogenous + and inhomogenous adjoint solutions on each segment. Default is `BacksolveAdjoint(autojacvec=ReverseDiffVJP())`. + * `autojacvec`: Calculate the vector-Jacobian product (`J'*v`) via automatic + differentiation with special seeding. The default is `true`. The total set + of choices are: + - `false`: the Jacobian is constructed via FiniteDiff.jl + - `true`: the Jacobian is constructed via ForwardDiff.jl + - `TrackerVJP`: Uses Tracker.jl for the vjp. + - `ZygoteVJP`: Uses Zygote.jl for the vjp. + - `EnzymeVJP`: Uses Enzyme.jl for the vjp. + - `ReverseDiffVJP(compile=false)`: Uses ReverseDiff.jl for the vjp. `compile` + is a boolean for whether to precompile the tape, which should only be done + if there are no branches (`if` or `while` statements) in the `f` function. +* `autodiff`: Use automatic differentiation for constructing the Jacobian + if the Jacobian needs to be constructed. Defaults to `true`. +* `chunk_size`: Chunk size for forward-mode differentiation if full Jacobians are + built (`autojacvec=false` and `autodiff=true`). Default is `0` for automatic + choice of chunk size. +* `diff_type`: The method used by FiniteDiff.jl for constructing the Jacobian + if the full Jacobian is required with `autodiff=false`. +* `g`: instantaneous objective function of the long-time averaged objective. + +## SciMLProblem Support + +This `sensealg` only supports `ODEProblem`s. This `sensealg` does not support +events (callbacks). This `sensealg` assumes that the objective is a long-time averaged +quantity and ergodic, i.e. the time evolution of the system behaves qualitatively the +same over infinite time independent of the specified initial conditions, such that only +the sensitivity with respect to the parameters is of interest. + +## References + +Ni, A., and Talnikar, C., Adjoint sensitivity analysis on chaotic dynamical systems +by Non-Intrusive Least Squares Adjoint Shadowing (NILSAS). Journal of Computational +Physics 395, 690-709 (2019). +""" +struct NILSAS{CS, AD, FDT, RNG, SENSE, gType} <: + AbstractShadowingSensitivityAlgorithm{CS, AD, FDT} + rng::RNG + adjoint_sensealg::SENSE + M::Int + nseg::Int + nstep::Int + g::gType +end +Base.@pure function NILSAS(nseg, nstep, M = nothing; + rng = Xorshifts.Xoroshiro128Plus(rand(UInt64)), + adjoint_sensealg = BacksolveAdjoint(autojacvec = ReverseDiffVJP()), + chunk_size = 0, autodiff = true, + diff_type = Val{:central}, + g = nothing) + + # integer dimension of the unstable subspace + M === nothing && + error("Please provide an `M` with `M >= nus + 1`, where nus is the number of unstable covariant Lyapunov vectors.") + + NILSAS{chunk_size, autodiff, diff_type, typeof(rng), typeof(adjoint_sensealg), typeof(g) + }(rng, adjoint_sensealg, M, + nseg, nstep, g) +end + +""" +SteadyStateAdjoint{CS,AD,FDT,VJP,LS} <: AbstractAdjointSensitivityAlgorithm{CS,AD,FDT} + +An implementation of the adjoint differentiation of a nonlinear solve. Uses the +implicit function theorem to directly compute the derivative of the solution to +``f(u,p) = 0`` with respect to `p`. + +## Constructor + +```julia +SteadyStateAdjoint(;chunk_size = 0, autodiff = true, + diff_type = Val{:central}, + autojacvec = autodiff, linsolve = nothing) +``` + +## Keyword Arguments + +* `autodiff`: Use automatic differentiation for constructing the Jacobian + if the Jacobian needs to be constructed. Defaults to `true`. +* `chunk_size`: Chunk size for forward-mode differentiation if full Jacobians are + built (`autojacvec=false` and `autodiff=true`). Default is `0` for automatic + choice of chunk size. +* `diff_type`: The method used by FiniteDiff.jl for constructing the Jacobian + if the full Jacobian is required with `autodiff=false`. +* `autojacvec`: Calculate the vector-Jacobian product (`J'*v`) via automatic + differentiation with special seeding. The default is `nothing`. The total set + of choices are: + - `false`: the Jacobian is constructed via FiniteDiff.jl + - `true`: the Jacobian is constructed via ForwardDiff.jl + - `TrackerVJP`: Uses Tracker.jl for the vjp. + - `ZygoteVJP`: Uses Zygote.jl for the vjp. + - `EnzymeVJP`: Uses Enzyme.jl for the vjp. + - `ReverseDiffVJP(compile=false)`: Uses ReverseDiff.jl for the vjp. `compile` + is a boolean for whether to precompile the tape, which should only be done + if there are no branches (`if` or `while` statements) in the `f` function. +* `linsolve`: the linear solver used in the adjoint solve. Defaults to `nothing`, + which uses a polyalgorithm to attempt to automatically choose an efficient + algorithm. + +For more details on the vjp choices, please consult the sensitivity algorithms +documentation page or the docstrings of the vjp types. + +## References + +Johnson, S. G., Notes on Adjoint Methods for 18.336, Online at +http://math.mit.edu/stevenj/18.336/adjoint.pdf (2007) +""" +struct SteadyStateAdjoint{CS, AD, FDT, VJP, LS} <: + AbstractAdjointSensitivityAlgorithm{CS, AD, FDT} + autojacvec::VJP + linsolve::LS +end + +Base.@pure function SteadyStateAdjoint(; chunk_size = 0, autodiff = true, + diff_type = Val{:central}, + autojacvec = nothing, linsolve = nothing) + SteadyStateAdjoint{chunk_size, autodiff, diff_type, typeof(autojacvec), typeof(linsolve) + }(autojacvec, linsolve) +end +function setvjp(sensealg::SteadyStateAdjoint{CS, AD, FDT, LS}, vjp) where {CS, AD, FDT, LS} + SteadyStateAdjoint{CS, AD, FDT, typeof(vjp), LS}(vjp, sensealg.linsolve) +end + +abstract type VJPChoice end + +""" +ZygoteVJP <: VJPChoice + +Uses Zygote.jl to compute vector-Jacobian products. Tends to be the fastest VJP method if the +ODE/DAE/SDE/DDE is written with mostly vectorized functions (like neural networks and other +layers from Flux.jl) and the `f` functions is given out-of-place. If the `f` function is +in-place, then `Zygote.Buffer` arrays are used internally which can greatly reduce the +performance of the VJP method. + +## Constructor + +```julia +ZygoteVJP(;allow_nothing=false) +``` + +Keyword arguments: + +* `allow_nothing`: whether `nothing`s should be implicitly converted to zeros. In Zygote, + the derivative of a function with respect to `p` which does not use `p` in any possible + calculation is given a derivative of `nothing` instead of zero. By default, this `nothing` + is caught in order to throw an informative error message about a potentially unintentional + misdefined function. However, if this was intentional, setting `allow_nothing=true` will + remove the error message. + +""" +struct ZygoteVJP <: VJPChoice + allow_nothing::Bool +end +ZygoteVJP(; allow_nothing = false) = ZygoteVJP(allow_nothing) + +""" +EnzymeVJP <: VJPChoice + +Uses Enzyme.jl to compute vector-Jacobian products. Is the fastest VJP whenever applicable, +though Enzyme.jl currently has low coverage over the Julia programming language, for example +restricting the user's defined `f` function to not do things like require garbage collection +or calls to BLAS/LAPACK. However, mutation is supported, meaning that in-place `f` with +fully mutating non-allocating code will work with Enzyme (provided no high level calls to C +like BLAS/LAPACK are used) and this will be the most efficient adjoint implementation. + +## Constructor + +```julia +EnzymeVJP(compile=false) +``` +""" +struct EnzymeVJP <: VJPChoice end + +""" +TrackerVJP <: VJPChoice + +Uses Tracker.jl to compute the vector-Jacobian products. If `f` is in-place, +then it uses a array of structs formulation to do scalarized reverse mode, +while if `f` is out-of-place then it uses an array-based reverse mode. + +Not as efficient as `ReverseDiffVJP`, but supports GPUs when doing array-based +reverse mode. + +## Constructor + +```julia +TrackerVJP(;allow_nothing=false) +``` + +Keyword arguments: + +* `allow_nothing`: whether non-tracked values should be implicitly converted to zeros. In Tracker, + the derivative of a function with respect to `p` which does not use `p` in any possible + calculation is given an untracked return instead of zero. By default, this `nothing` Trackedness + is caught in order to throw an informative error message about a potentially unintentional + misdefined function. However, if this was intentional, setting `allow_nothing=true` will + remove the error message. +""" +struct TrackerVJP <: VJPChoice + allow_nothing::Bool +end +TrackerVJP(; allow_nothing = false) = TrackerVJP(allow_nothing) + +""" +ReverseDiffVJP{compile} <: VJPChoice + +Uses ReverseDiff.jl to compute the vector-Jacobian products. If `f` is in-place, +then it uses a array of structs formulation to do scalarized reverse mode, +while if `f` is out-of-place then it uses an array-based reverse mode. + +Usually the fastest when scalarized operations exist in the f function +(like in scientific machine learning applications like Universal Differential Equations) +and the boolean compilation is enabled (i.e. ReverseDiffVJP(true)), if EnzymeVJP fails on +a given choice of `f`. + +Does not support GPUs (CuArrays). + +## Constructor + +```julia +ReverseDiffVJP(compile=false) +``` + +## Keyword Arguments + +* `compile`: Whether to cache the compilation of the reverse tape. This heavily increases + the performance of the method but requires that the `f` function of the ODE/DAE/SDE/DDE + has no branching. +""" +struct ReverseDiffVJP{compile} <: VJPChoice + ReverseDiffVJP(compile = false) = new{compile}() +end + +@inline convert_tspan(::ForwardDiffSensitivity{CS, CTS}) where {CS, CTS} = CTS +@inline convert_tspan(::Any) = nothing +@inline function alg_autodiff(alg::DiffEqBase.AbstractSensitivityAlgorithm{CS, AD, FDT}) where { + CS, + AD, + FDT + } + AD +end +@inline function get_chunksize(alg::DiffEqBase.AbstractSensitivityAlgorithm{CS, AD, FDT}) where { + CS, + AD, + FDT + } + CS +end +@inline function diff_type(alg::DiffEqBase.AbstractSensitivityAlgorithm{CS, AD, FDT}) where { + CS, + AD, + FDT + } + FDT +end +@inline function get_jacvec(alg::DiffEqBase.AbstractSensitivityAlgorithm) + alg.autojacvec isa Bool ? alg.autojacvec : true +end +@inline function get_jacmat(alg::DiffEqBase.AbstractSensitivityAlgorithm) + alg.autojacmat isa Bool ? alg.autojacmat : true +end +@inline ischeckpointing(alg::DiffEqBase.AbstractSensitivityAlgorithm, sol = nothing) = false +@inline ischeckpointing(alg::InterpolatingAdjoint) = alg.checkpointing +@inline ischeckpointing(alg::InterpolatingAdjoint, sol) = alg.checkpointing || !sol.dense +@inline ischeckpointing(alg::BacksolveAdjoint, sol = nothing) = alg.checkpointing + +@inline isnoisemixing(alg::DiffEqBase.AbstractSensitivityAlgorithm) = false +@inline isnoisemixing(alg::InterpolatingAdjoint) = alg.noisemixing +@inline isnoisemixing(alg::BacksolveAdjoint) = alg.noisemixing + +@inline compile_tape(vjp::ReverseDiffVJP{compile}) where {compile} = compile +@inline compile_tape(autojacvec::Bool) = false + +""" +ForwardDiffOverAdjoint{A} <: AbstractSecondOrderSensitivityAlgorithm{nothing,true,nothing} + +ForwardDiff.jl over a choice of `sensealg` method for the adjoint. + +## Constructor + +```julia +ForwardDiffOverAdjoint(sensealg) +``` + +## SciMLProblem Support + +This supports any SciMLProblem that the `sensealg` choice supports, provided the solver algorithm +is `SciMLBase.isautodifferentiable`. + +## References + +Hindmarsh, A. C. and Brown, P. N. and Grant, K. E. and Lee, S. L. and Serban, R. +and Shumaker, D. E. and Woodward, C. S., SUNDIALS: Suite of nonlinear and +differential/algebraic equation solvers, ACM Transactions on Mathematical +Software (TOMS), 31, pp:363–396 (2005) +""" +struct ForwardDiffOverAdjoint{A} <: + AbstractSecondOrderSensitivityAlgorithm{nothing, true, nothing} + adjalg::A +end diff --git a/src/sensitivity_interface.jl b/src/sensitivity_interface.jl index 9344669df..d9027019e 100644 --- a/src/sensitivity_interface.jl +++ b/src/sensitivity_interface.jl @@ -1,479 +1,480 @@ -## Direct calls - -const ADJOINT_PARAMETER_COMPATABILITY_MESSAGE = -""" -Adjoint sensitivity analysis functionality requires being able to solve -a differential equation defined by the parameter struct `p`. Thus while -DifferentialEquations.jl can support any parameter struct type, usage -with adjoint sensitivity analysis requires that `p` could be a valid -type for being the initial condition `u0` of an array. This means that -many simple types, such as `Tuple`s and `NamedTuple`s, will work as -parameters in normal contexts but will fail during adjoint differentiation. -To work around this issue for complicated cases like nested structs, look -into defining `p` using `AbstractArray` libraries such as RecursiveArrayTools.jl -or ComponentArrays.jl so that `p` is an `AbstractArray` with a concrete element type. -""" - -struct AdjointSensitivityParameterCompatibilityError <: Exception end - -function Base.showerror(io::IO, e::AdjointSensitivityParameterCompatibilityError) - print(io, ADJOINT_PARAMETER_COMPATABILITY_MESSAGE) -end - -@doc doc""" -```julia -adjoint_sensitivities(sol,alg;t=nothing,dg_discrete=nothing, - dg_continuous=nothing,g=nothing, - abstol=1e-6,reltol=1e-3, - checkpoints=sol.t, - corfunc_analytical=nothing, - callback = nothing, - sensealg=InterpolatingAdjoint(), - kwargs...) -``` - -Adjoint sensitivity analysis is used to find the gradient of the solution -with respect to some functional of the solution. In many cases this is used -in an optimization problem to return the gradient with respect to some cost -function. It is equivalent to "backpropagation" or reverse-mode automatic -differentiation of a differential equation. - -Using `adjoint_sensitivities` directly let's you do three things. One it can -allow you to be more efficient, since the sensitivity calculation can be done -directly on a cost function, avoiding the overhead of building the derivative -of the full concretized solution. It can also allow you to be more efficient -by directly controlling the forward solve that is then reversed over. Lastly, -it allows one to define a continuous cost function on the continuous solution, -instead of just at discrete data points. - -!!! warning - - Adjoint sensitivity analysis functionality requires being able to solve - a differential equation defined by the parameter struct `p`. Thus while - DifferentialEquations.jl can support any parameter struct type, usage - with adjoint sensitivity analysis requires that `p` could be a valid - type for being the initial condition `u0` of an array. This means that - many simple types, such as `Tuple`s and `NamedTuple`s, will work as - parameters in normal contexts but will fail during adjoint differentiation. - To work around this issue for complicated cases like nested structs, look - into defining `p` using `AbstractArray` libraries such as RecursiveArrayTools.jl - or ComponentArrays.jl so that `p` is an `AbstractArray` with a concrete element type. - -!!! warning - - Non-checkpointed InterpolatingAdjoint and QuadratureAdjoint sensealgs - require that the forward solution `sol(t)` has an accurate dense - solution unless checkpointing is used. This means that you should - not use `solve(prob,alg,saveat=ts)` unless checkpointing. If specific - saving is required, one should solve dense `solve(prob,alg)`, use the - solution in the adjoint, and then `sol(ts)` interpolate. - -### Syntax - -For discrete adjoints, use: - -```julia -du0,dp = adjoint_sensitivities(sol,alg;t=ts,dg_discrete=dg,sensealg=InterpolatingAdjoint(), - checkpoints=sol.t,kwargs...) -``` - -where `alg` is the ODE algorithm to solve the adjoint problem, `dg` is the jump -function, `sensealg` is the sensitivity algorithm, and `ts` is the time points -for data. `dg` is given by: - -```julia -dg(out,u,p,t,i) -``` - -which is the in-place gradient of the cost functional `g` at time point `ts[i]` -with `u=u(t)`. - -For continuous functionals, the form is: - -```julia -du0,dp = adjoint_sensitivities(sol,alg;dg_continuous=(dgdu,dgdp),g=g,sensealg=InterpolatingAdjoint(), - checkpoints=sol.t,kwargs...) -``` - -for the cost functional - -```julia -g(u,p,t) -``` - -with in-place gradient - -```julia -dgdu(out,u,p,t) -dgdp(out,u,p,t) -``` - -If the gradient is omitted, i.e. - -```julia -du0,dp = adjoint_sensitivities(sol,alg;g=g,kwargs...) -``` - -then we assume `dgdp` is zero and `dgdu` will be computed automatically using ForwardDiff or finite -differencing, depending on the `autodiff` setting in the `AbstractSensitivityAlgorithm`. -Note that the keyword arguments are passed to the internal ODE solver for -solving the adjoint problem. - -### Example discrete adjoints on a cost function - -In this example we will show solving for the adjoint sensitivities of a discrete -cost functional. First let's solve the ODE and get a high quality continuous -solution: - -```julia -function f(du,u,p,t) - du[1] = dx = p[1]*u[1] - p[2]*u[1]*u[2] - du[2] = dy = -p[3]*u[2] + u[1]*u[2] -end - -p = [1.5,1.0,3.0] -prob = ODEProblem(f,[1.0;1.0],(0.0,10.0),p) -sol = solve(prob,Vern9(),abstol=1e-10,reltol=1e-10) -``` - -Now let's calculate the sensitivity of the ``\ell_2`` error against 1 at evenly spaced -points in time, that is: - -```math -L(u,p,t)=\sum_{i=1}^{n}\frac{\Vert1-u(t_{i},p)\Vert^{2}}{2} -``` - -for ``t_i = 0.5i``. This is the assumption that the data is `data[i]=1.0`. -For this function, notice we have that: - -```math -\begin{aligned} -dg_{1}&=-1+u_{1} \\ -dg_{2}&=-1+u_{2} \\ -& \quad \vdots -\end{aligned} -``` - -and thus: - -```julia -dg(out,u,p,t,i) = (out.=-1.0.+u) -``` - -Also, we can omit `dgdp`, because the cost function doesn't dependent on `p`. If we had data, we'd just replace `1.0` with `data[i]`. To get the adjoint -sensitivities, call: - -```julia -ts = 0:0.5:10 -res = adjoint_sensitivities(sol,Vern9();t=ts,dg_discrete=dg,abstol=1e-14, - reltol=1e-14) -``` - -This is super high accuracy. As always, there's a tradeoff between accuracy -and computation time. We can check this almost exactly matches the -autodifferentiation and numerical differentiation results: - -```julia -using ForwardDiff,Calculus,Tracker -function G(p) - tmp_prob = remake(prob,u0=convert.(eltype(p),prob.u0),p=p) - sol = solve(tmp_prob,Vern9(),abstol=1e-14,reltol=1e-14,saveat=ts, - sensealg=SensitivityADPassThrough()) - A = convert(Array,sol) - sum(((1 .- A).^2)./2) -end -G([1.5,1.0,3.0]) -res2 = ForwardDiff.gradient(G,[1.5,1.0,3.0]) -res3 = Calculus.gradient(G,[1.5,1.0,3.0]) -res4 = Tracker.gradient(G,[1.5,1.0,3.0]) -res5 = ReverseDiff.gradient(G,[1.5,1.0,3.0]) -``` - -and see this gives the same values. - -### Example controlling adjoint method choices and checkpointing - -In the previous examples, all calculations were done using the interpolating -method. This maximizes speed but at a cost of requiring a dense `sol`. If it -is not possible to hold a dense forward solution in memory, then one can use -checkpointing. For example: - -```julia -ts = [0.0,0.2,0.5,0.7] -sol = solve(prob,Vern9(),saveat=ts) -``` - -Creates a non-dense solution with checkpoints at `[0.0,0.2,0.5,0.7]`. Now we -can do: - -```julia -res = adjoint_sensitivities(sol,Vern9();t=ts,dg_discrete=dg, - sensealg=InterpolatingAdjoint(checkpointing=true)) -``` - -When grabbing a Jacobian value during the backwards solution, it will no longer -interpolate to get the value. Instead, it will start a forward solution at the -nearest checkpoint to build local interpolants in a way that conserves memory. -By default the checkpoints are at `sol.t`, but we can override this: - -```julia -res = adjoint_sensitivities(sol,Vern9();t=ts,dg_discrte=dg, - sensealg=InterpolatingAdjoint(checkpointing=true), - checkpoints = [0.0,0.5]) -``` - -### Example continuous adjoints on an energy functional - -In this case we'd like to calculate the adjoint sensitivity of the scalar energy -functional: - -```math -G(u,p)=\int_{0}^{T}\frac{\sum_{i=1}^{n}u_{i}^{2}(t)}{2}dt -``` - -which is: - -```julia -g(u,p,t) = (sum(u).^2) ./ 2 -``` - -Notice that the gradient of this function with respect to the state `u` is: - -```julia -function dg(out,u,p,t) - out[1]= u[1] + u[2] - out[2]= u[1] + u[2] -end -``` - -To get the adjoint sensitivities, we call: - -```julia -res = adjoint_sensitivities(sol,Vern9();dg_continuous=dg,g=g,abstol=1e-8, - reltol=1e-8,iabstol=1e-8,ireltol=1e-8) -``` - -Notice that we can check this against autodifferentiation and numerical -differentiation as follows: - -```julia -using QuadGK -function G(p) - tmp_prob = remake(prob,p=p) - sol = solve(tmp_prob,Vern9(),abstol=1e-14,reltol=1e-14) - res,err = quadgk((t)-> (sum(sol(t)).^2)./2,0.0,10.0,atol=1e-14,rtol=1e-10) - res -end -res2 = ForwardDiff.gradient(G,[1.5,1.0,3.0]) -res3 = Calculus.gradient(G,[1.5,1.0,3.0]) -``` -""" -function adjoint_sensitivities(sol,args...; - sensealg=InterpolatingAdjoint(), - verbose=true,kwargs...) - if hasfield(typeof(sensealg),:autojacvec) && sensealg.autojacvec === nothing - if haskey(kwargs, :callback) - has_cb = kwargs[:callback] !== nothing - else - has_cb = false - end - if !has_cb - _sensealg = if isinplace(sol.prob) - setvjp(sensealg, inplace_vjp(sol.prob, sol.prob.u0, sol.prob.p, verbose)) - else - setvjp(sensealg, ZygoteVJP()) - end - else - _sensealg = setvjp(sensealg, ReverseDiffVJP()) - end - - return try - _adjoint_sensitivities(sol, _sensealg, args...; verbose, kwargs...) - catch e - verbose && @warn "Automatic AD choice of autojacvec failed in ODE adjoint, failing back to ODE adjoint + numerical vjp" - _adjoint_sensitivities(sol, setvjp(sensealg, false), args...; verbose, kwargs...) - end - else - return _adjoint_sensitivities(sol, sensealg, args...; verbose, kwargs...) - end -end - -function _adjoint_sensitivities(sol, sensealg, alg; - t=nothing, - dg_discrete=nothing, dg_continuous=nothing, - g=nothing, - abstol=1e-6, reltol=1e-3, - checkpoints=sol.t, - corfunc_analytical=nothing, - callback=nothing, - kwargs...) - - if !(typeof(sol.prob.p) <: Union{Nothing,SciMLBase.NullParameters,AbstractArray}) || (sol.prob.p isa AbstractArray && !Base.isconcretetype(eltype(sol.prob.p))) - throw(AdjointSensitivityParameterCompatibilityError()) - end - - if sol.prob isa ODEProblem - adj_prob = ODEAdjointProblem(sol, sensealg, t, dg_discrete, dg_continuous, g; - checkpoints=checkpoints, - callback=callback, - abstol=abstol, reltol=reltol, kwargs...) - - elseif sol.prob isa SDEProblem - adj_prob = SDEAdjointProblem(sol, sensealg, t, dg_discrete, dg_continuous, g; - checkpoints=checkpoints, - callback=callback, - abstol=abstol, reltol=reltol, - corfunc_analytical=corfunc_analytical) - elseif sol.prob isa RODEProblem - adj_prob = RODEAdjointProblem(sol, sensealg, t, dg_discrete, dg_continuous, g; - checkpoints=checkpoints, - callback=callback, - abstol=abstol, reltol=reltol, - corfunc_analytical=corfunc_analytical) - else - error("Continuous adjoint sensitivities are only supported for ODE/SDE/RODE problems.") - end - - tstops = ischeckpointing(sensealg, sol) ? checkpoints : similar(sol.t, 0) - adj_sol = solve(adj_prob, alg; - save_everystep=false, save_start=false, saveat=eltype(sol[1])[], - tstops=tstops, abstol=abstol, reltol=reltol, kwargs...) - - p = sol.prob.p - l = p === nothing || p === DiffEqBase.NullParameters() ? 0 : length(sol.prob.p) - du0 = adj_sol[end][1:length(sol.prob.u0)] - - if eltype(sol.prob.p) <: real(eltype(adj_sol[end])) - dp = real.(adj_sol[end][(1:l).+length(sol.prob.u0)])' - elseif p === nothing || p === DiffEqBase.NullParameters() - dp = nothing - else - dp = adj_sol[end][(1:l).+length(sol.prob.u0)]' - end - - du0,dp -end - -function _adjoint_sensitivities(sol,sensealg::SteadyStateAdjoint,alg,g,dg=nothing; - abstol=1e-6,reltol=1e-3, - kwargs...) - SteadyStateAdjointProblem(sol,sensealg,g,dg;kwargs...) -end - -function _adjoint_sensitivities(sol,sensealg::SteadyStateAdjoint,alg; - g=nothing,dg=nothing, - abstol=1e-6,reltol=1e-3, - kwargs...) - SteadyStateAdjointProblem(sol,sensealg,g,dg;kwargs...) -end - -@doc doc""" -H = second_order_sensitivities(loss,prob,alg,args...; - sensealg=ForwardDiffOverAdjoint(InterpolatingAdjoint(autojacvec=ReverseDiffVJP())), - kwargs...) - -Second order sensitivity analysis is used for the fast calculation of Hessian -matrices. - -!!! warning - - Adjoint sensitivity analysis functionality requires being able to solve - a differential equation defined by the parameter struct `p`. Thus while - DifferentialEquations.jl can support any parameter struct type, usage - with adjoint sensitivity analysis requires that `p` could be a valid - type for being the initial condition `u0` of an array. This means that - many simple types, such as `Tuple`s and `NamedTuple`s, will work as - parameters in normal contexts but will fail during adjoint differentiation. - To work around this issue for complicated cases like nested structs, look - into defining `p` using `AbstractArray` libraries such as RecursiveArrayTools.jl - or ComponentArrays.jl so that `p` is an `AbstractArray` with a concrete element type. - -### Example second order sensitivity analysis calculation - -```julia -using SciMLSensitivity, OrdinaryDiffEq, ForwardDiff -using Test - -function lotka!(du,u,p,t) - du[1] = dx = p[1]*u[1] - p[2]*u[1]*u[2] - du[2] = dy = -p[3]*u[2] + p[4]*u[1]*u[2] -end - -p = [1.5,1.0,3.0,1.0]; u0 = [1.0;1.0] -prob = ODEProblem(lotka!,u0,(0.0,10.0),p) -loss(sol) = sum(sol) -v = ones(4) - -H = second_order_sensitivities(loss,prob,Vern9(),saveat=0.1,abstol=1e-12,reltol=1e-12) -``` - -## Arguments - -The arguments for this function match `adjoint_sensitivities`. The only notable difference -is `sensealg` which requires a second order sensitivity algorithm, of which currently the -only choice is `ForwardDiffOverAdjoint` which uses forward-over-reverse to mix a forward-mode -sensitivity analysis with an adjoint sensitivity analysis for a faster computation than either -double forward or double reverse. `ForwardDiffOverAdjoint`'s positional argument just accepts -a first order sensitivity algorithm. -""" -function second_order_sensitivities(loss,prob,alg,args...; - sensealg=ForwardDiffOverAdjoint(InterpolatingAdjoint(autojacvec=ReverseDiffVJP())), - kwargs...) - _second_order_sensitivities(loss,prob,alg,sensealg,args...;kwargs...) -end - -@doc doc""" -Hv = second_order_sensitivity_product(loss,v,prob,alg,args...; - sensealg=ForwardDiffOverAdjoint(InterpolatingAdjoint(autojacvec=ReverseDiffVJP())), - kwargs...) - -Second order sensitivity analysis product is used for the fast calculation of -Hessian-vector products ``Hv`` without requiring the construction of the Hessian -matrix. - -!!! warning - - Adjoint sensitivity analysis functionality requires being able to solve - a differential equation defined by the parameter struct `p`. Thus while - DifferentialEquations.jl can support any parameter struct type, usage - with adjoint sensitivity analysis requires that `p` could be a valid - type for being the initial condition `u0` of an array. This means that - many simple types, such as `Tuple`s and `NamedTuple`s, will work as - parameters in normal contexts but will fail during adjoint differentiation. - To work around this issue for complicated cases like nested structs, look - into defining `p` using `AbstractArray` libraries such as RecursiveArrayTools.jl - or ComponentArrays.jl so that `p` is an `AbstractArray` with a concrete element type. - -### Example second order sensitivity analysis calculation - -```julia -using SciMLSensitivity, OrdinaryDiffEq, ForwardDiff -using Test - -function lotka!(du,u,p,t) - du[1] = dx = p[1]*u[1] - p[2]*u[1]*u[2] - du[2] = dy = -p[3]*u[2] + p[4]*u[1]*u[2] -end - -p = [1.5,1.0,3.0,1.0]; u0 = [1.0;1.0] -prob = ODEProblem(lotka!,u0,(0.0,10.0),p) -loss(sol) = sum(sol) -v = ones(4) - -Hv = second_order_sensitivity_product(loss,v,prob,Vern9(),saveat=0.1,abstol=1e-12,reltol=1e-12) -``` - -## Arguments - -The arguments for this function match `adjoint_sensitivities`. The only notable difference -is `sensealg` which requires a second order sensitivity algorithm, of which currently the -only choice is `ForwardDiffOverAdjoint` which uses forward-over-reverse to mix a forward-mode -sensitivity analysis with an adjoint sensitivity analysis for a faster computation than either -double forward or double reverse. `ForwardDiffOverAdjoint`'s positional argument just accepts -a first order sensitivity algorithm. -""" -function second_order_sensitivity_product(loss,v,prob,alg,args...; - sensealg=ForwardDiffOverAdjoint(InterpolatingAdjoint(autojacvec=ReverseDiffVJP())), - kwargs...) - _second_order_sensitivity_product(loss,v,prob,alg,sensealg,args...;kwargs...) -end +## Direct calls + +const ADJOINT_PARAMETER_COMPATABILITY_MESSAGE = """ + Adjoint sensitivity analysis functionality requires being able to solve + a differential equation defined by the parameter struct `p`. Thus while + DifferentialEquations.jl can support any parameter struct type, usage + with adjoint sensitivity analysis requires that `p` could be a valid + type for being the initial condition `u0` of an array. This means that + many simple types, such as `Tuple`s and `NamedTuple`s, will work as + parameters in normal contexts but will fail during adjoint differentiation. + To work around this issue for complicated cases like nested structs, look + into defining `p` using `AbstractArray` libraries such as RecursiveArrayTools.jl + or ComponentArrays.jl so that `p` is an `AbstractArray` with a concrete element type. + """ + +struct AdjointSensitivityParameterCompatibilityError <: Exception end + +function Base.showerror(io::IO, e::AdjointSensitivityParameterCompatibilityError) + print(io, ADJOINT_PARAMETER_COMPATABILITY_MESSAGE) +end + +@doc doc""" +```julia +adjoint_sensitivities(sol,alg;t=nothing,dg_discrete=nothing, + dg_continuous=nothing,g=nothing, + abstol=1e-6,reltol=1e-3, + checkpoints=sol.t, + corfunc_analytical=nothing, + callback = nothing, + sensealg=InterpolatingAdjoint(), + kwargs...) +``` + +Adjoint sensitivity analysis is used to find the gradient of the solution +with respect to some functional of the solution. In many cases this is used +in an optimization problem to return the gradient with respect to some cost +function. It is equivalent to "backpropagation" or reverse-mode automatic +differentiation of a differential equation. + +Using `adjoint_sensitivities` directly let's you do three things. One it can +allow you to be more efficient, since the sensitivity calculation can be done +directly on a cost function, avoiding the overhead of building the derivative +of the full concretized solution. It can also allow you to be more efficient +by directly controlling the forward solve that is then reversed over. Lastly, +it allows one to define a continuous cost function on the continuous solution, +instead of just at discrete data points. + +!!! warning + + Adjoint sensitivity analysis functionality requires being able to solve + a differential equation defined by the parameter struct `p`. Thus while + DifferentialEquations.jl can support any parameter struct type, usage + with adjoint sensitivity analysis requires that `p` could be a valid + type for being the initial condition `u0` of an array. This means that + many simple types, such as `Tuple`s and `NamedTuple`s, will work as + parameters in normal contexts but will fail during adjoint differentiation. + To work around this issue for complicated cases like nested structs, look + into defining `p` using `AbstractArray` libraries such as RecursiveArrayTools.jl + or ComponentArrays.jl so that `p` is an `AbstractArray` with a concrete element type. + +!!! warning + + Non-checkpointed InterpolatingAdjoint and QuadratureAdjoint sensealgs + require that the forward solution `sol(t)` has an accurate dense + solution unless checkpointing is used. This means that you should + not use `solve(prob,alg,saveat=ts)` unless checkpointing. If specific + saving is required, one should solve dense `solve(prob,alg)`, use the + solution in the adjoint, and then `sol(ts)` interpolate. + +### Syntax + +For discrete adjoints, use: + +```julia +du0,dp = adjoint_sensitivities(sol,alg;t=ts,dg_discrete=dg,sensealg=InterpolatingAdjoint(), + checkpoints=sol.t,kwargs...) +``` + +where `alg` is the ODE algorithm to solve the adjoint problem, `dg` is the jump +function, `sensealg` is the sensitivity algorithm, and `ts` is the time points +for data. `dg` is given by: + +```julia +dg(out,u,p,t,i) +``` + +which is the in-place gradient of the cost functional `g` at time point `ts[i]` +with `u=u(t)`. + +For continuous functionals, the form is: + +```julia +du0,dp = adjoint_sensitivities(sol,alg;dg_continuous=(dgdu,dgdp),g=g,sensealg=InterpolatingAdjoint(), + checkpoints=sol.t,kwargs...) +``` + +for the cost functional + +```julia +g(u,p,t) +``` + +with in-place gradient + +```julia +dgdu(out,u,p,t) +dgdp(out,u,p,t) +``` + +If the gradient is omitted, i.e. + +```julia +du0,dp = adjoint_sensitivities(sol,alg;g=g,kwargs...) +``` + +then we assume `dgdp` is zero and `dgdu` will be computed automatically using ForwardDiff or finite +differencing, depending on the `autodiff` setting in the `AbstractSensitivityAlgorithm`. +Note that the keyword arguments are passed to the internal ODE solver for +solving the adjoint problem. + +### Example discrete adjoints on a cost function + +In this example we will show solving for the adjoint sensitivities of a discrete +cost functional. First let's solve the ODE and get a high quality continuous +solution: + +```julia +function f(du,u,p,t) + du[1] = dx = p[1]*u[1] - p[2]*u[1]*u[2] + du[2] = dy = -p[3]*u[2] + u[1]*u[2] +end + +p = [1.5,1.0,3.0] +prob = ODEProblem(f,[1.0;1.0],(0.0,10.0),p) +sol = solve(prob,Vern9(),abstol=1e-10,reltol=1e-10) +``` + +Now let's calculate the sensitivity of the ``\ell_2`` error against 1 at evenly spaced +points in time, that is: + +```math +L(u,p,t)=\sum_{i=1}^{n}\frac{\Vert1-u(t_{i},p)\Vert^{2}}{2} +``` + +for ``t_i = 0.5i``. This is the assumption that the data is `data[i]=1.0`. +For this function, notice we have that: + +```math +\begin{aligned} +dg_{1}&=-1+u_{1} \\ +dg_{2}&=-1+u_{2} \\ +& \quad \vdots +\end{aligned} +``` + +and thus: + +```julia +dg(out,u,p,t,i) = (out.=-1.0.+u) +``` + +Also, we can omit `dgdp`, because the cost function doesn't dependent on `p`. If we had data, we'd just replace `1.0` with `data[i]`. To get the adjoint +sensitivities, call: + +```julia +ts = 0:0.5:10 +res = adjoint_sensitivities(sol,Vern9();t=ts,dg_discrete=dg,abstol=1e-14, + reltol=1e-14) +``` + +This is super high accuracy. As always, there's a tradeoff between accuracy +and computation time. We can check this almost exactly matches the +autodifferentiation and numerical differentiation results: + +```julia +using ForwardDiff,Calculus,Tracker +function G(p) + tmp_prob = remake(prob,u0=convert.(eltype(p),prob.u0),p=p) + sol = solve(tmp_prob,Vern9(),abstol=1e-14,reltol=1e-14,saveat=ts, + sensealg=SensitivityADPassThrough()) + A = convert(Array,sol) + sum(((1 .- A).^2)./2) +end +G([1.5,1.0,3.0]) +res2 = ForwardDiff.gradient(G,[1.5,1.0,3.0]) +res3 = Calculus.gradient(G,[1.5,1.0,3.0]) +res4 = Tracker.gradient(G,[1.5,1.0,3.0]) +res5 = ReverseDiff.gradient(G,[1.5,1.0,3.0]) +``` + +and see this gives the same values. + +### Example controlling adjoint method choices and checkpointing + +In the previous examples, all calculations were done using the interpolating +method. This maximizes speed but at a cost of requiring a dense `sol`. If it +is not possible to hold a dense forward solution in memory, then one can use +checkpointing. For example: + +```julia +ts = [0.0,0.2,0.5,0.7] +sol = solve(prob,Vern9(),saveat=ts) +``` + +Creates a non-dense solution with checkpoints at `[0.0,0.2,0.5,0.7]`. Now we +can do: + +```julia +res = adjoint_sensitivities(sol,Vern9();t=ts,dg_discrete=dg, + sensealg=InterpolatingAdjoint(checkpointing=true)) +``` + +When grabbing a Jacobian value during the backwards solution, it will no longer +interpolate to get the value. Instead, it will start a forward solution at the +nearest checkpoint to build local interpolants in a way that conserves memory. +By default the checkpoints are at `sol.t`, but we can override this: + +```julia +res = adjoint_sensitivities(sol,Vern9();t=ts,dg_discrte=dg, + sensealg=InterpolatingAdjoint(checkpointing=true), + checkpoints = [0.0,0.5]) +``` + +### Example continuous adjoints on an energy functional + +In this case we'd like to calculate the adjoint sensitivity of the scalar energy +functional: + +```math +G(u,p)=\int_{0}^{T}\frac{\sum_{i=1}^{n}u_{i}^{2}(t)}{2}dt +``` + +which is: + +```julia +g(u,p,t) = (sum(u).^2) ./ 2 +``` + +Notice that the gradient of this function with respect to the state `u` is: + +```julia +function dg(out,u,p,t) + out[1]= u[1] + u[2] + out[2]= u[1] + u[2] +end +``` + +To get the adjoint sensitivities, we call: + +```julia +res = adjoint_sensitivities(sol,Vern9();dg_continuous=dg,g=g,abstol=1e-8, + reltol=1e-8,iabstol=1e-8,ireltol=1e-8) +``` + +Notice that we can check this against autodifferentiation and numerical +differentiation as follows: + +```julia +using QuadGK +function G(p) + tmp_prob = remake(prob,p=p) + sol = solve(tmp_prob,Vern9(),abstol=1e-14,reltol=1e-14) + res,err = quadgk((t)-> (sum(sol(t)).^2)./2,0.0,10.0,atol=1e-14,rtol=1e-10) + res +end +res2 = ForwardDiff.gradient(G,[1.5,1.0,3.0]) +res3 = Calculus.gradient(G,[1.5,1.0,3.0]) +``` +""" +function adjoint_sensitivities(sol, args...; + sensealg = InterpolatingAdjoint(), + verbose = true, kwargs...) + if hasfield(typeof(sensealg), :autojacvec) && sensealg.autojacvec === nothing + if haskey(kwargs, :callback) + has_cb = kwargs[:callback] !== nothing + else + has_cb = false + end + if !has_cb + _sensealg = if isinplace(sol.prob) + setvjp(sensealg, inplace_vjp(sol.prob, sol.prob.u0, sol.prob.p, verbose)) + else + setvjp(sensealg, ZygoteVJP()) + end + else + _sensealg = setvjp(sensealg, ReverseDiffVJP()) + end + + return try + _adjoint_sensitivities(sol, _sensealg, args...; verbose, kwargs...) + catch e + verbose && + @warn "Automatic AD choice of autojacvec failed in ODE adjoint, failing back to ODE adjoint + numerical vjp" + _adjoint_sensitivities(sol, setvjp(sensealg, false), args...; verbose, + kwargs...) + end + else + return _adjoint_sensitivities(sol, sensealg, args...; verbose, kwargs...) + end +end + +function _adjoint_sensitivities(sol, sensealg, alg; + t = nothing, + dg_discrete = nothing, dg_continuous = nothing, + g = nothing, + abstol = 1e-6, reltol = 1e-3, + checkpoints = sol.t, + corfunc_analytical = nothing, + callback = nothing, + kwargs...) + if !(typeof(sol.prob.p) <: Union{Nothing, SciMLBase.NullParameters, AbstractArray}) || + (sol.prob.p isa AbstractArray && !Base.isconcretetype(eltype(sol.prob.p))) + throw(AdjointSensitivityParameterCompatibilityError()) + end + + if sol.prob isa ODEProblem + adj_prob = ODEAdjointProblem(sol, sensealg, t, dg_discrete, dg_continuous, g; + checkpoints = checkpoints, + callback = callback, + abstol = abstol, reltol = reltol, kwargs...) + + elseif sol.prob isa SDEProblem + adj_prob = SDEAdjointProblem(sol, sensealg, t, dg_discrete, dg_continuous, g; + checkpoints = checkpoints, + callback = callback, + abstol = abstol, reltol = reltol, + corfunc_analytical = corfunc_analytical) + elseif sol.prob isa RODEProblem + adj_prob = RODEAdjointProblem(sol, sensealg, t, dg_discrete, dg_continuous, g; + checkpoints = checkpoints, + callback = callback, + abstol = abstol, reltol = reltol, + corfunc_analytical = corfunc_analytical) + else + error("Continuous adjoint sensitivities are only supported for ODE/SDE/RODE problems.") + end + + tstops = ischeckpointing(sensealg, sol) ? checkpoints : similar(sol.t, 0) + adj_sol = solve(adj_prob, alg; + save_everystep = false, save_start = false, saveat = eltype(sol[1])[], + tstops = tstops, abstol = abstol, reltol = reltol, kwargs...) + + p = sol.prob.p + l = p === nothing || p === DiffEqBase.NullParameters() ? 0 : length(sol.prob.p) + du0 = adj_sol[end][1:length(sol.prob.u0)] + + if eltype(sol.prob.p) <: real(eltype(adj_sol[end])) + dp = real.(adj_sol[end][(1:l) .+ length(sol.prob.u0)])' + elseif p === nothing || p === DiffEqBase.NullParameters() + dp = nothing + else + dp = adj_sol[end][(1:l) .+ length(sol.prob.u0)]' + end + + du0, dp +end + +function _adjoint_sensitivities(sol, sensealg::SteadyStateAdjoint, alg, g, dg = nothing; + abstol = 1e-6, reltol = 1e-3, + kwargs...) + SteadyStateAdjointProblem(sol, sensealg, g, dg; kwargs...) +end + +function _adjoint_sensitivities(sol, sensealg::SteadyStateAdjoint, alg; + g = nothing, dg = nothing, + abstol = 1e-6, reltol = 1e-3, + kwargs...) + SteadyStateAdjointProblem(sol, sensealg, g, dg; kwargs...) +end + +@doc doc""" +H = second_order_sensitivities(loss,prob,alg,args...; + sensealg=ForwardDiffOverAdjoint(InterpolatingAdjoint(autojacvec=ReverseDiffVJP())), + kwargs...) + +Second order sensitivity analysis is used for the fast calculation of Hessian +matrices. + +!!! warning + + Adjoint sensitivity analysis functionality requires being able to solve + a differential equation defined by the parameter struct `p`. Thus while + DifferentialEquations.jl can support any parameter struct type, usage + with adjoint sensitivity analysis requires that `p` could be a valid + type for being the initial condition `u0` of an array. This means that + many simple types, such as `Tuple`s and `NamedTuple`s, will work as + parameters in normal contexts but will fail during adjoint differentiation. + To work around this issue for complicated cases like nested structs, look + into defining `p` using `AbstractArray` libraries such as RecursiveArrayTools.jl + or ComponentArrays.jl so that `p` is an `AbstractArray` with a concrete element type. + +### Example second order sensitivity analysis calculation + +```julia +using SciMLSensitivity, OrdinaryDiffEq, ForwardDiff +using Test + +function lotka!(du,u,p,t) + du[1] = dx = p[1]*u[1] - p[2]*u[1]*u[2] + du[2] = dy = -p[3]*u[2] + p[4]*u[1]*u[2] +end + +p = [1.5,1.0,3.0,1.0]; u0 = [1.0;1.0] +prob = ODEProblem(lotka!,u0,(0.0,10.0),p) +loss(sol) = sum(sol) +v = ones(4) + +H = second_order_sensitivities(loss,prob,Vern9(),saveat=0.1,abstol=1e-12,reltol=1e-12) +``` + +## Arguments + +The arguments for this function match `adjoint_sensitivities`. The only notable difference +is `sensealg` which requires a second order sensitivity algorithm, of which currently the +only choice is `ForwardDiffOverAdjoint` which uses forward-over-reverse to mix a forward-mode +sensitivity analysis with an adjoint sensitivity analysis for a faster computation than either +double forward or double reverse. `ForwardDiffOverAdjoint`'s positional argument just accepts +a first order sensitivity algorithm. +""" +function second_order_sensitivities(loss, prob, alg, args...; + sensealg = ForwardDiffOverAdjoint(InterpolatingAdjoint(autojacvec = ReverseDiffVJP())), + kwargs...) + _second_order_sensitivities(loss, prob, alg, sensealg, args...; kwargs...) +end + +@doc doc""" +Hv = second_order_sensitivity_product(loss,v,prob,alg,args...; + sensealg=ForwardDiffOverAdjoint(InterpolatingAdjoint(autojacvec=ReverseDiffVJP())), + kwargs...) + +Second order sensitivity analysis product is used for the fast calculation of +Hessian-vector products ``Hv`` without requiring the construction of the Hessian +matrix. + +!!! warning + + Adjoint sensitivity analysis functionality requires being able to solve + a differential equation defined by the parameter struct `p`. Thus while + DifferentialEquations.jl can support any parameter struct type, usage + with adjoint sensitivity analysis requires that `p` could be a valid + type for being the initial condition `u0` of an array. This means that + many simple types, such as `Tuple`s and `NamedTuple`s, will work as + parameters in normal contexts but will fail during adjoint differentiation. + To work around this issue for complicated cases like nested structs, look + into defining `p` using `AbstractArray` libraries such as RecursiveArrayTools.jl + or ComponentArrays.jl so that `p` is an `AbstractArray` with a concrete element type. + +### Example second order sensitivity analysis calculation + +```julia +using SciMLSensitivity, OrdinaryDiffEq, ForwardDiff +using Test + +function lotka!(du,u,p,t) + du[1] = dx = p[1]*u[1] - p[2]*u[1]*u[2] + du[2] = dy = -p[3]*u[2] + p[4]*u[1]*u[2] +end + +p = [1.5,1.0,3.0,1.0]; u0 = [1.0;1.0] +prob = ODEProblem(lotka!,u0,(0.0,10.0),p) +loss(sol) = sum(sol) +v = ones(4) + +Hv = second_order_sensitivity_product(loss,v,prob,Vern9(),saveat=0.1,abstol=1e-12,reltol=1e-12) +``` + +## Arguments + +The arguments for this function match `adjoint_sensitivities`. The only notable difference +is `sensealg` which requires a second order sensitivity algorithm, of which currently the +only choice is `ForwardDiffOverAdjoint` which uses forward-over-reverse to mix a forward-mode +sensitivity analysis with an adjoint sensitivity analysis for a faster computation than either +double forward or double reverse. `ForwardDiffOverAdjoint`'s positional argument just accepts +a first order sensitivity algorithm. +""" +function second_order_sensitivity_product(loss, v, prob, alg, args...; + sensealg = ForwardDiffOverAdjoint(InterpolatingAdjoint(autojacvec = ReverseDiffVJP())), + kwargs...) + _second_order_sensitivity_product(loss, v, prob, alg, sensealg, args...; kwargs...) +end diff --git a/src/steadystate_adjoint.jl b/src/steadystate_adjoint.jl index 1f23be561..4c4510d65 100644 --- a/src/steadystate_adjoint.jl +++ b/src/steadystate_adjoint.jl @@ -1,14 +1,14 @@ struct SteadyStateAdjointSensitivityFunction{ - C<:AdjointDiffCache, - Alg<:SteadyStateAdjoint, - uType, - SType, - fType<:ODEFunction, - CV, - λType, - VJPType, - LS, -} <: SensitivityFunction + C <: AdjointDiffCache, + Alg <: SteadyStateAdjoint, + uType, + SType, + fType <: ODEFunction, + CV, + λType, + VJPType, + LS + } <: SensitivityFunction diffcache::C sensealg::Alg discrete::Bool @@ -21,75 +21,65 @@ struct SteadyStateAdjointSensitivityFunction{ linsolve::LS end -function SteadyStateAdjointSensitivityFunction( - g, - sensealg, - discrete, - sol, - dg, - colorvec, - needs_jac, -) +function SteadyStateAdjointSensitivityFunction(g, + sensealg, + discrete, + sol, + dg, + colorvec, + needs_jac) @unpack f, p, u0 = sol.prob - diffcache, y = adjointdiffcache( - g, - sensealg, - discrete, - sol, - dg, - f; - quad = false, - needs_jac = needs_jac, - ) + diffcache, y = adjointdiffcache(g, + sensealg, + discrete, + sol, + dg, + f; + quad = false, + needs_jac = needs_jac) λ = zero(y) linsolve = needs_jac ? nothing : sensealg.linsolve vjp = similar(λ, length(p)) - SteadyStateAdjointSensitivityFunction( - diffcache, - sensealg, - discrete, - y, - sol, - f, - colorvec, - λ, - vjp, - linsolve, - ) + SteadyStateAdjointSensitivityFunction(diffcache, + sensealg, + discrete, + y, + sol, + f, + colorvec, + λ, + vjp, + linsolve) end -@noinline function SteadyStateAdjointProblem( - sol, - sensealg::SteadyStateAdjoint, - g, - dg; - save_idxs = nothing, - kwargs... -) +@noinline function SteadyStateAdjointProblem(sol, + sensealg::SteadyStateAdjoint, + g, + dg; + save_idxs = nothing, + kwargs...) @unpack f, p, u0 = sol.prob discrete = false # TODO: What is the correct heuristic? Can we afford to compute Jacobian for # cases where the length(u0) > 50 and if yes till what threshold - needs_jac = (sensealg.linsolve === nothing && length(u0) <= 50) || LinearSolve.needs_concrete_A(sensealg.linsolve) - - p === DiffEqBase.NullParameters() && error( - "Your model does not have parameters, and thus it is impossible to calculate the derivative of the solution with respect to the parameters. Your model must have parameters to use parameter sensitivity calculations!", - ) - - sense = SteadyStateAdjointSensitivityFunction( - g, - sensealg, - discrete, - sol, - dg, - f.colorvec, - needs_jac, - ) + needs_jac = (sensealg.linsolve === nothing && length(u0) <= 50) || + LinearSolve.needs_concrete_A(sensealg.linsolve) + + p === DiffEqBase.NullParameters() && + error("Your model does not have parameters, and thus it is impossible to calculate the derivative of the solution with respect to the parameters. Your model must have parameters to use parameter sensitivity calculations!") + + sense = SteadyStateAdjointSensitivityFunction(g, + sensealg, + discrete, + sol, + dg, + f.colorvec, + needs_jac) @unpack diffcache, y, sol, λ, vjp, linsolve = sense if needs_jac @@ -97,14 +87,12 @@ end f.jac(diffcache.J, y, p, nothing) else if DiffEqBase.isinplace(sol.prob) - jacobian!( - diffcache.J, - diffcache.uf, - y, - diffcache.f_cache, - sensealg, - diffcache.jac_config, - ) + jacobian!(diffcache.J, + diffcache.uf, + y, + diffcache.f_cache, + sensealg, + diffcache.jac_config) else temp = jacobian(diffcache.uf, y, sensealg) @. diffcache.J = temp @@ -127,42 +115,40 @@ end end else if g !== nothing - gradient!( - vec(diffcache.dg_val), - diffcache.g, - y, - sensealg, - diffcache.g_grad_config, - ) + gradient!(vec(diffcache.dg_val), + diffcache.g, + y, + sensealg, + diffcache.g_grad_config) end end if !needs_jac # NOTE: Zygote doesn't support inplace - linear_problem = LinearProblem(VecJacOperator(f, y, p; autodiff = !DiffEqBase.isinplace(sol.prob)), + linear_problem = LinearProblem(VecJacOperator(f, y, p; + autodiff = !DiffEqBase.isinplace(sol.prob)), vec(diffcache.dg_val), u0 = vec(λ)) else - linear_problem = LinearProblem(diffcache.J',vec(diffcache.dg_val'),u0 = vec(λ)) + linear_problem = LinearProblem(diffcache.J', vec(diffcache.dg_val'), u0 = vec(λ)) end solve(linear_problem, linsolve) # u is vec(λ) try - vecjacobian!( - vec(diffcache.dg_val), - y, - λ, - p, - nothing, - sense, - dgrad = vjp, - dy = nothing - ) + vecjacobian!(vec(diffcache.dg_val), + y, + λ, + p, + nothing, + sense, + dgrad = vjp, + dy = nothing) catch e if sense.sensealg.autojacvec === nothing @warn "Automatic AD choice of autojacvec failed in nonlinear solve adjoint, failing back to ODE adjoint + numerical vjp" - vecjacobian!(vec(diffcache.dg_val),y,λ,p,nothing,false,dgrad = vjp,dy = nothing) + vecjacobian!(vec(diffcache.dg_val), y, λ, p, nothing, false, dgrad = vjp, + dy = nothing) else @warn "AD choice of autojacvec failed in nonlinear solve adjoint" throw(e) diff --git a/src/tracker.jl b/src/tracker.jl index 3e47b74dd..77977284f 100644 --- a/src/tracker.jl +++ b/src/tracker.jl @@ -1,24 +1,38 @@ # Piracy that used to be requires, allowing Tracker.jl to be specialized for SciML -function RecursiveArrayTools.recursivecopy!(b::AbstractArray{T,N}, a::AbstractArray{T2,N}) where {T<:Tracker.TrackedArray,T2<:Tracker.TrackedArray,N} +function RecursiveArrayTools.recursivecopy!(b::AbstractArray{T, N}, + a::AbstractArray{T2, N}) where { + T <: + Tracker.TrackedArray, + T2 <: + Tracker.TrackedArray, + N} @inbounds for i in eachindex(a) b[i] = copy(a[i]) end end DiffEqBase.value(x::Type{Tracker.TrackedReal{T}}) where {T} = T -DiffEqBase.value(x::Type{Tracker.TrackedArray{T,N,A}}) where {T,N,A} = Array{T,N} +DiffEqBase.value(x::Type{Tracker.TrackedArray{T, N, A}}) where {T, N, A} = Array{T, N} DiffEqBase.value(x::Tracker.TrackedReal) = x.data DiffEqBase.value(x::Tracker.TrackedArray) = x.data DiffEqBase.promote_u0(u0::Tracker.TrackedArray, p::Tracker.TrackedArray, t0) = u0 -DiffEqBase.promote_u0(u0::AbstractArray{<:Tracker.TrackedReal}, p::Tracker.TrackedArray, t0) = u0 -DiffEqBase.promote_u0(u0::Tracker.TrackedArray, p::AbstractArray{<:Tracker.TrackedReal}, t0) = u0 -DiffEqBase.promote_u0(u0::AbstractArray{<:Tracker.TrackedReal}, p::AbstractArray{<:Tracker.TrackedReal}, t0) = u0 +function DiffEqBase.promote_u0(u0::AbstractArray{<:Tracker.TrackedReal}, + p::Tracker.TrackedArray, t0) + u0 +end +function DiffEqBase.promote_u0(u0::Tracker.TrackedArray, + p::AbstractArray{<:Tracker.TrackedReal}, t0) + u0 +end +function DiffEqBase.promote_u0(u0::AbstractArray{<:Tracker.TrackedReal}, + p::AbstractArray{<:Tracker.TrackedReal}, t0) + u0 +end DiffEqBase.promote_u0(u0, p::Tracker.TrackedArray, t0) = Tracker.track(u0) DiffEqBase.promote_u0(u0, p::AbstractArray{<:Tracker.TrackedReal}, t0) = eltype(p).(u0) - @inline DiffEqBase.fastpow(x::Tracker.TrackedReal, y::Tracker.TrackedReal) = x^y @inline Base.any(f::Function, x::Tracker.TrackedArray) = any(f, Tracker.data(x)) @@ -26,41 +40,62 @@ DiffEqBase.promote_u0(u0, p::AbstractArray{<:Tracker.TrackedReal}, t0) = eltype( @inline function DiffEqBase.ODE_DEFAULT_NORM(u::Tracker.TrackedArray, t) where {N} sqrt(sum(abs2, DiffEqBase.value(u)) / length(u)) end -@inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{<:Tracker.TrackedReal,N}, t) where {N} - sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip((DiffEqBase.value(x) for x in u), Iterators.repeated(t))) / length(u)) +@inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{<:Tracker.TrackedReal, N}, + t) where {N} + sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), + zip((DiffEqBase.value(x) for x in u), Iterators.repeated(t))) / length(u)) end -@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Array{<:Tracker.TrackedReal,N}, t) where {N} - sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip((DiffEqBase.value(x) for x in u), Iterators.repeated(t))) / length(u)) +@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Array{<:Tracker.TrackedReal, N}, + t) where {N} + sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), + zip((DiffEqBase.value(x) for x in u), Iterators.repeated(t))) / length(u)) end @inline DiffEqBase.ODE_DEFAULT_NORM(u::Tracker.TrackedReal, t) = abs(DiffEqBase.value(u)) # Support TrackedReal time, don't drop tracking on the adaptivity there -@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Tracker.TrackedArray, t::Tracker.TrackedReal) where {N} +@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Tracker.TrackedArray, + t::Tracker.TrackedReal) where {N} sqrt(sum(abs2, u) / length(u)) end -@inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{<:Tracker.TrackedReal,N}, t::Tracker.TrackedReal) where {N} - sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip(u, Iterators.repeated(t))) / length(u)) +@inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{<:Tracker.TrackedReal, N}, + t::Tracker.TrackedReal) where {N} + sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip(u, Iterators.repeated(t))) / + length(u)) end -@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Array{<:Tracker.TrackedReal,N}, t::Tracker.TrackedReal) where {N} - sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip(u, Iterators.repeated(t))) / length(u)) +@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Array{<:Tracker.TrackedReal, N}, + t::Tracker.TrackedReal) where {N} + sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip(u, Iterators.repeated(t))) / + length(u)) end @inline DiffEqBase.ODE_DEFAULT_NORM(u::Tracker.TrackedReal, t::Tracker.TrackedReal) = abs(u) -function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem, sensealg::Union{DiffEqBase.AbstractSensitivityAlgorithm,Nothing}, u0::Tracker.TrackedArray, p::Tracker.TrackedArray, args...; kwargs...) +function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem, + sensealg::Union{DiffEqBase.AbstractSensitivityAlgorithm, + Nothing}, u0::Tracker.TrackedArray, + p::Tracker.TrackedArray, args...; kwargs...) Tracker.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) end -function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem, sensealg::Union{DiffEqBase.AbstractSensitivityAlgorithm,Nothing}, u0::Tracker.TrackedArray, p, args...; kwargs...) +function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem, + sensealg::Union{DiffEqBase.AbstractSensitivityAlgorithm, + Nothing}, u0::Tracker.TrackedArray, p, args...; + kwargs...) Tracker.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) end -function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem, sensealg::Union{DiffEqBase.AbstractSensitivityAlgorithm,Nothing}, u0, p::Tracker.TrackedArray, args...; kwargs...) +function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem, + sensealg::Union{DiffEqBase.AbstractSensitivityAlgorithm, + Nothing}, u0, p::Tracker.TrackedArray, args...; + kwargs...) Tracker.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) end -Tracker.@grad function DiffEqBase.solve_up(prob, sensealg::Union{Nothing,DiffEqBase.AbstractSensitivityAlgorithm}, - u0, p, args...; - kwargs...) - DiffEqBase._solve_adjoint(prob, sensealg, Tracker.data(u0), Tracker.data(p), +Tracker.@grad function DiffEqBase.solve_up(prob, + sensealg::Union{Nothing, + DiffEqBase.AbstractSensitivityAlgorithm + }, + u0, p, args...; + kwargs...) + DiffEqBase._solve_adjoint(prob, sensealg, Tracker.data(u0), Tracker.data(p), SciMLBase.TrackerOriginator(), args...; kwargs...) end diff --git a/src/zygote.jl b/src/zygote.jl index 9c4a6e9d1..7da2bd152 100644 --- a/src/zygote.jl +++ b/src/zygote.jl @@ -17,25 +17,32 @@ function ∇tmap(cx, f, args...) end function ∇responsible_map(cx, f, args...) - ys_and_backs = SciMLBase.responsible_map((args...) -> Zygote._pullback(cx, f, args...), args...) + ys_and_backs = SciMLBase.responsible_map((args...) -> Zygote._pullback(cx, f, args...), + args...) if isempty(ys_and_backs) ys_and_backs, _ -> (NoTangent(), NoTangent()) else ys, backs = Zygote.unzip(ys_and_backs) - ys, function ∇responsible_map_internal(Δ) + ys, + function ∇responsible_map_internal(Δ) # Apply pullbacks in reverse order. Needed for correctness if `f` is stateful. - Δf_and_args_zipped = SciMLBase.responsible_map((f, δ) -> f(δ), Zygote._tryreverse(SciMLBase.responsible_map, backs, Δ)...) - Δf_and_args = Zygote.unzip(Zygote._tryreverse(SciMLBase.responsible_map, Δf_and_args_zipped)) + Δf_and_args_zipped = SciMLBase.responsible_map((f, δ) -> f(δ), + Zygote._tryreverse(SciMLBase.responsible_map, + backs, Δ)...) + Δf_and_args = Zygote.unzip(Zygote._tryreverse(SciMLBase.responsible_map, + Δf_and_args_zipped)) Δf = reduce(Zygote.accum, Δf_and_args[1]) (Δf, Δf_and_args[2:end]...) end end end -ZygoteRules.@adjoint function SciMLBase.tmap(f, args::Union{AbstractArray,Tuple}...) +ZygoteRules.@adjoint function SciMLBase.tmap(f, args::Union{AbstractArray, Tuple}...) ∇tmap(__context__, f, args...) end -ZygoteRules.@adjoint function SciMLBase.responsible_map(f, args::Union{AbstractArray,Tuple}...) +ZygoteRules.@adjoint function SciMLBase.responsible_map(f, + args::Union{AbstractArray, Tuple + }...) ∇responsible_map(__context__, f, args...) -end \ No newline at end of file +end diff --git a/test/HybridNODE.jl b/test/HybridNODE.jl index a88a227fb..f6ecf2609 100644 --- a/test/HybridNODE.jl +++ b/test/HybridNODE.jl @@ -5,36 +5,36 @@ using Zygote function test_hybridNODE(sensealg) Random.seed!(12345) datalength = 100 - tspan = (0.0,100.0) - t = range(tspan[1],tspan[2],length=datalength) - target = 3.0*(1:datalength)./datalength # some dummy data to fit to + tspan = (0.0, 100.0) + t = range(tspan[1], tspan[2], length = datalength) + target = 3.0 * (1:datalength) ./ datalength # some dummy data to fit to cbinput = rand(1, datalength) #some external ODE contribution - pmodel = Chain( - Dense(2, 10, init=zeros), - Dense(10, 2, init=zeros)) + pmodel = Chain(Dense(2, 10, init = zeros), + Dense(10, 2, init = zeros)) p, re = Flux.destructure(pmodel) - dudt(u,p,t) = re(p)(u) + dudt(u, p, t) = re(p)(u) # callback changes the first component of the solution every time # t is an integer function affect!(integrator, cbinput) - event_index = round(Int,integrator.t) - integrator.u[1] += 0.2*cbinput[event_index] + event_index = round(Int, integrator.t) + integrator.u[1] += 0.2 * cbinput[event_index] end - callback = PresetTimeCallback(collect(1:datalength),(int)->affect!(int, cbinput)) + callback = PresetTimeCallback(collect(1:datalength), (int) -> affect!(int, cbinput)) # ODE with Callback - prob = ODEProblem(dudt,[0.0, 1.0],tspan,p) + prob = ODEProblem(dudt, [0.0, 1.0], tspan, p) function predict_n_ode(p) arr = Array(solve(prob, Tsit5(), - p=p, sensealg=sensealg, saveat=2.0, callback=callback))[1,2:2:end] + p = p, sensealg = sensealg, saveat = 2.0, callback = callback))[1, + 2:2:end] return arr[1:datalength] end function loss_n_ode() pred = predict_n_ode(p) - loss = sum(abs2,target .- pred)./datalength + loss = sum(abs2, target .- pred) ./ datalength end cb = function () #callback function to observe training @@ -42,60 +42,64 @@ function test_hybridNODE(sensealg) display(loss_n_ode()) end @show sensealg - Flux.train!(loss_n_ode, Flux.params(p), Iterators.repeated((), 20), ADAM(0.005), cb = cb) + Flux.train!(loss_n_ode, Flux.params(p), Iterators.repeated((), 20), ADAM(0.005), + cb = cb) @test loss_n_ode() < 0.5 println(" ") end function test_hybridNODE2(sensealg) Random.seed!(12345) - u0 = Float32[2.; 0.; 0.; 0.] - tspan = (0f0,1f0) + u0 = Float32[2.0; 0.0; 0.0; 0.0] + tspan = (0.0f0, 1.0f0) ## Get goal data function trueaffect!(integrator) - integrator.u[3:4] = -3*integrator.u[1:2] + integrator.u[3:4] = -3 * integrator.u[1:2] end - function trueODEfunc(dx,x,p,t) + function trueODEfunc(dx, x, p, t) @views dx[1:2] .= x[1:2] + x[3:4] dx[1] += x[2] dx[2] += x[1] - dx[3:4] .= 0f0 + dx[3:4] .= 0.0f0 end - cb_ = PeriodicCallback(trueaffect!,0.1f0,save_positions=(true,true),initial_affect=true) - prob = ODEProblem(trueODEfunc,u0,tspan) - sol = solve(prob,Tsit5(),callback=cb_,save_everystep=false,save_start=true) - ode_data = Array(sol)[1:2,1:end]' + cb_ = PeriodicCallback(trueaffect!, 0.1f0, save_positions = (true, true), + initial_affect = true) + prob = ODEProblem(trueODEfunc, u0, tspan) + sol = solve(prob, Tsit5(), callback = cb_, save_everystep = false, save_start = true) + ode_data = Array(sol)[1:2, 1:end]' ## Make model - dudt2 = Chain(Dense(4,50,tanh), - Dense(50,2)) - p,re = Flux.destructure(dudt2) # use this p as the initial condition! + dudt2 = Chain(Dense(4, 50, tanh), + Dense(50, 2)) + p, re = Flux.destructure(dudt2) # use this p as the initial condition! function affect!(integrator) - integrator.u[3:4] = -3*integrator.u[1:2] + integrator.u[3:4] = -3 * integrator.u[1:2] end - function ODEfunc(dx,x,p,t) + function ODEfunc(dx, x, p, t) dx[1:2] .= re(p)(x) - dx[3:4] .= 0f0 + dx[3:4] .= 0.0f0 end z0 = u0 - prob = ODEProblem(ODEfunc,z0,tspan) - cb = PeriodicCallback(affect!,0.1f0,save_positions=(true,true),initial_affect=true) + prob = ODEProblem(ODEfunc, z0, tspan) + cb = PeriodicCallback(affect!, 0.1f0, save_positions = (true, true), + initial_affect = true) ## Initialize learning functions function predict_n_ode() - _prob = remake(prob,p=p) - Array(solve(_prob,Tsit5(),u0=z0,p=p,callback=cb,save_everystep=false,save_start=true,sensealg=sensealg))[1:2,:] + _prob = remake(prob, p = p) + Array(solve(_prob, Tsit5(), u0 = z0, p = p, callback = cb, save_everystep = false, + save_start = true, sensealg = sensealg))[1:2, :] end function loss_n_ode() - pred = predict_n_ode()[1:2,1:end]' - loss = sum(abs2,ode_data .- pred) + pred = predict_n_ode()[1:2, 1:end]' + loss = sum(abs2, ode_data .- pred) loss end loss_n_ode() # n_ode.p stores the initial parameters of the neural ODE cba = function () #callback function to observe training - pred = predict_n_ode()[1:2,1:end]' - display(sum(abs2,ode_data .- pred)) + pred = predict_n_ode()[1:2, 1:end]' + display(sum(abs2, ode_data .- pred)) return false end cba() @@ -114,79 +118,81 @@ function test_hybridNODE2(sensealg) end mutable struct Affect{T} - callback_data::T + callback_data::T end -compute_index(t) = round(Int,t)+1 +compute_index(t) = round(Int, t) + 1 function (cb::Affect)(integrator) indx = compute_index(integrator.t) - integrator.u .= integrator.u .+ @view(cb.callback_data[:, indx, 1]) * (integrator.t - integrator.tprev) + integrator.u .= integrator.u .+ + @view(cb.callback_data[:, indx, 1]) * (integrator.t - integrator.tprev) end function test_hybridNODE3(sensealg) - u0 = Float32[2.; 0.] + u0 = Float32[2.0; 0.0] datasize = 101 - tspan = (0.0f0,10.0f0) - - function trueODEfunc(du,u,p,t) + tspan = (0.0f0, 10.0f0) + + function trueODEfunc(du, u, p, t) du .= -u end - t = range(tspan[1],tspan[2],length=datasize) - prob = ODEProblem(trueODEfunc,u0,tspan) - ode_data = Array(solve(prob,Tsit5(),saveat=t)) - - true_data = reshape(ode_data,(2,length(t),1)) - true_data = convert.(Float32,true_data) - callback_data = true_data * 1f-3 - train_dataloader = Flux.Data.DataLoader((true_data = true_data,callback_data = callback_data),batchsize=1) - dudt2 = Chain(Dense(2,50,tanh), - Dense(50,2)) - p,re = Flux.destructure(dudt2) - function dudt(du,u,p,t) - du .= re(p)(u) - end - z0 = Float32[2.; 0.] - prob = ODEProblem(dudt,z0,tspan) - - + t = range(tspan[1], tspan[2], length = datasize) + prob = ODEProblem(trueODEfunc, u0, tspan) + ode_data = Array(solve(prob, Tsit5(), saveat = t)) + + true_data = reshape(ode_data, (2, length(t), 1)) + true_data = convert.(Float32, true_data) + callback_data = true_data * 1.0f-3 + train_dataloader = Flux.Data.DataLoader((true_data = true_data, + callback_data = callback_data), batchsize = 1) + dudt2 = Chain(Dense(2, 50, tanh), + Dense(50, 2)) + p, re = Flux.destructure(dudt2) + function dudt(du, u, p, t) + du .= re(p)(u) + end + z0 = Float32[2.0; 0.0] + prob = ODEProblem(dudt, z0, tspan) + function callback_(callback_data) affect! = Affect(callback_data) - condition(u,t,integrator) = integrator.t > 0 - DiscreteCallback(condition,affect!,save_positions=(false,false)) + condition(u, t, integrator) = integrator.t > 0 + DiscreteCallback(condition, affect!, save_positions = (false, false)) end - - function predict_n_ode(true_data_0,callback_data, sense) - _prob = remake(prob,p=p,u0=true_data_0) - solve(_prob,Tsit5(),callback=callback_(callback_data),saveat=t,sensealg=sense) + + function predict_n_ode(true_data_0, callback_data, sense) + _prob = remake(prob, p = p, u0 = true_data_0) + solve(_prob, Tsit5(), callback = callback_(callback_data), saveat = t, + sensealg = sense) end - - function loss_n_ode(true_data,callback_data) - sol = predict_n_ode((vec(true_data[:,1,:])),callback_data,sensealg) + + function loss_n_ode(true_data, callback_data) + sol = predict_n_ode((vec(true_data[:, 1, :])), callback_data, sensealg) pred = Array(sol) - loss = Flux.mse((true_data[:,:,1]),pred) + loss = Flux.mse((true_data[:, :, 1]), pred) loss end ps = Flux.params(p) opt = ADAM(0.1) epochs = 10 - function cb1(true_data,callback_data) - display(loss_n_ode(true_data,callback_data)) + function cb1(true_data, callback_data) + display(loss_n_ode(true_data, callback_data)) return false end function train!(loss, ps, data, opt, cb) ps = Params(ps) - for (true_data,callback_data) in data + for (true_data, callback_data) in data gs = gradient(ps) do - loss(true_data,callback_data) + loss(true_data, callback_data) end - Flux.update!(opt, ps, gs) - cb(true_data,callback_data) + Flux.update!(opt, ps, gs) + cb(true_data, callback_data) end return nothing end - @Flux.epochs epochs train!(loss_n_ode, Params(ps),train_dataloader, opt, cb1) - loss = loss_n_ode(true_data[:,:,1],callback_data) + Flux.@epochs epochs train!(loss_n_ode, Params(ps), train_dataloader, opt, cb1) + loss = loss_n_ode(true_data[:, :, 1], callback_data) @info loss @test loss < 0.5 end @@ -211,4 +217,3 @@ end test_hybridNODE3(InterpolatingAdjoint()) test_hybridNODE3(QuadratureAdjoint()) end - diff --git a/test/adjoint.jl b/test/adjoint.jl index 5f649f32c..03e596ee5 100644 --- a/test/adjoint.jl +++ b/test/adjoint.jl @@ -2,118 +2,140 @@ using SciMLSensitivity, OrdinaryDiffEq, RecursiveArrayTools, DiffEqBase, ForwardDiff, Calculus, QuadGK, LinearAlgebra, Zygote using Test -function fb(du,u,p,t) - du[1] = dx = p[1]*u[1] - p[2]*u[1]*u[2]*t - du[2] = dy = -p[3]*u[2] + t*p[4]*u[1]*u[2] +function fb(du, u, p, t) + du[1] = dx = p[1] * u[1] - p[2] * u[1] * u[2] * t + du[2] = dy = -p[3] * u[2] + t * p[4] * u[1] * u[2] end -function foop(u,p,t) - dx = p[1]*u[1] - p[2]*u[1]*u[2]*t - dy = -p[3]*u[2] + t*p[4]*u[1]*u[2] - [dx,dy] +function foop(u, p, t) + dx = p[1] * u[1] - p[2] * u[1] * u[2] * t + dy = -p[3] * u[2] + t * p[4] * u[1] * u[2] + [dx, dy] end -function jac(J,u,p,t) - (x, y, a, b, c, d) = (u[1], u[2], p[1], p[2], p[3], p[4]) - J[1,1] = a + y * b * -1 * t - J[2,1] = t * y * d - J[1,2] = b * x * -1 * t - J[2,2] = c * -1 + t * x * d +function jac(J, u, p, t) + (x, y, a, b, c, d) = (u[1], u[2], p[1], p[2], p[3], p[4]) + J[1, 1] = a + y * b * -1 * t + J[2, 1] = t * y * d + J[1, 2] = b * x * -1 * t + J[2, 2] = c * -1 + t * x * d end -f = ODEFunction(fb,jac=jac) -p = [1.5,1.0,3.0,1.0]; u0 = [1.0;1.0] -prob = ODEProblem(f,u0,(0.0,10.0),p) -sol = solve(prob,Tsit5(),abstol=1e-14,reltol=1e-14) -probb = ODEProblem(fb,u0,(0.0,10.0),p) -proboop = ODEProblem(foop,u0,(0.0,10.0),p) +f = ODEFunction(fb, jac = jac) +p = [1.5, 1.0, 3.0, 1.0]; +u0 = [1.0; 1.0]; +prob = ODEProblem(f, u0, (0.0, 10.0), p) +sol = solve(prob, Tsit5(), abstol = 1e-14, reltol = 1e-14) +probb = ODEProblem(fb, u0, (0.0, 10.0), p) +proboop = ODEProblem(foop, u0, (0.0, 10.0), p) -solb = solve(probb,Tsit5(),abstol=1e-14,reltol=1e-14) -sol_end = solve(probb,Tsit5(),abstol=1e-14,reltol=1e-14, - save_everystep=false,save_start=false) +solb = solve(probb, Tsit5(), abstol = 1e-14, reltol = 1e-14) +sol_end = solve(probb, Tsit5(), abstol = 1e-14, reltol = 1e-14, + save_everystep = false, save_start = false) -sol_nodense = solve(probb,Tsit5(),abstol=1e-14,reltol=1e-14,dense=false) -soloop = solve(proboop,Tsit5(),abstol=1e-14,reltol=1e-14) -soloop_nodense = solve(proboop,Tsit5(),abstol=1e-14,reltol=1e-14,dense=false) +sol_nodense = solve(probb, Tsit5(), abstol = 1e-14, reltol = 1e-14, dense = false) +soloop = solve(proboop, Tsit5(), abstol = 1e-14, reltol = 1e-14) +soloop_nodense = solve(proboop, Tsit5(), abstol = 1e-14, reltol = 1e-14, dense = false) # Do a discrete adjoint problem println("Calculate discrete adjoint sensitivities") t = 0.0:0.5:10.0 # g(t,u,i) = (1-u)^2/2, L2 away from 1 -function dg(out,u,p,t,i) - (out.=-2.0.+u) +function dg(out, u, p, t, i) + (out .= -2.0 .+ u) end -_,easy_res = adjoint_sensitivities(sol,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, - reltol=1e-14) -_,easy_res2 = adjoint_sensitivities(solb,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, - reltol=1e-14, - sensealg=QuadratureAdjoint(abstol=1e-14,reltol=1e-14)) -_,easy_res22 = adjoint_sensitivities(solb,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, - reltol=1e-14, - sensealg=QuadratureAdjoint(autojacvec=false,abstol=1e-14,reltol=1e-14)) -_,easy_res23 = adjoint_sensitivities(solb,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, - reltol=1e-14, - sensealg=QuadratureAdjoint(abstol=1e-14,reltol=1e-14,autojacvec=ReverseDiffVJP(true))) -_,easy_res3 = adjoint_sensitivities(solb,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, - reltol=1e-14, - sensealg=InterpolatingAdjoint()) -_,easy_res32 = adjoint_sensitivities(solb,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, - reltol=1e-14, - sensealg=InterpolatingAdjoint(autojacvec=false)) -_,easy_res4 = adjoint_sensitivities(solb,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, - reltol=1e-14, - sensealg=BacksolveAdjoint()) -_,easy_res42 = adjoint_sensitivities(solb,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, - reltol=1e-14, - sensealg=BacksolveAdjoint(autojacvec=false)) -_,easy_res43 = adjoint_sensitivities(solb,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, - reltol=1e-14, - sensealg=BacksolveAdjoint(autojacvec=false,checkpointing=false)) -_,easy_res5 = adjoint_sensitivities(sol,Kvaerno5(nlsolve=NLAnderson(), smooth_est=false), - t=t,dg_discrete=dg,abstol=1e-12, - reltol=1e-10, - sensealg=BacksolveAdjoint()) -_,easy_res6 = adjoint_sensitivities(sol_nodense,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, - reltol=1e-14, - sensealg=InterpolatingAdjoint(checkpointing=true), - checkpoints=sol.t[1:500:end]) -_,easy_res62 = adjoint_sensitivities(sol_nodense,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, - reltol=1e-14, - sensealg=InterpolatingAdjoint(checkpointing=true,autojacvec=false), - checkpoints=sol.t[1:500:end]) +_, easy_res = adjoint_sensitivities(sol, Tsit5(), t = t, dg_discrete = dg, abstol = 1e-14, + reltol = 1e-14) +_, easy_res2 = adjoint_sensitivities(solb, Tsit5(), t = t, dg_discrete = dg, abstol = 1e-14, + reltol = 1e-14, + sensealg = QuadratureAdjoint(abstol = 1e-14, + reltol = 1e-14)) +_, easy_res22 = adjoint_sensitivities(solb, Tsit5(), t = t, dg_discrete = dg, + abstol = 1e-14, + reltol = 1e-14, + sensealg = QuadratureAdjoint(autojacvec = false, + abstol = 1e-14, + reltol = 1e-14)) +_, easy_res23 = adjoint_sensitivities(solb, Tsit5(), t = t, dg_discrete = dg, + abstol = 1e-14, + reltol = 1e-14, + sensealg = QuadratureAdjoint(abstol = 1e-14, + reltol = 1e-14, + autojacvec = ReverseDiffVJP(true))) +_, easy_res3 = adjoint_sensitivities(solb, Tsit5(), t = t, dg_discrete = dg, abstol = 1e-14, + reltol = 1e-14, + sensealg = InterpolatingAdjoint()) +_, easy_res32 = adjoint_sensitivities(solb, Tsit5(), t = t, dg_discrete = dg, + abstol = 1e-14, + reltol = 1e-14, + sensealg = InterpolatingAdjoint(autojacvec = false)) +_, easy_res4 = adjoint_sensitivities(solb, Tsit5(), t = t, dg_discrete = dg, abstol = 1e-14, + reltol = 1e-14, + sensealg = BacksolveAdjoint()) +_, easy_res42 = adjoint_sensitivities(solb, Tsit5(), t = t, dg_discrete = dg, + abstol = 1e-14, + reltol = 1e-14, + sensealg = BacksolveAdjoint(autojacvec = false)) +_, easy_res43 = adjoint_sensitivities(solb, Tsit5(), t = t, dg_discrete = dg, + abstol = 1e-14, + reltol = 1e-14, + sensealg = BacksolveAdjoint(autojacvec = false, + checkpointing = false)) +_, easy_res5 = adjoint_sensitivities(sol, + Kvaerno5(nlsolve = NLAnderson(), smooth_est = false), + t = t, dg_discrete = dg, abstol = 1e-12, + reltol = 1e-10, + sensealg = BacksolveAdjoint()) +_, easy_res6 = adjoint_sensitivities(sol_nodense, Tsit5(), t = t, dg_discrete = dg, + abstol = 1e-14, + reltol = 1e-14, + sensealg = InterpolatingAdjoint(checkpointing = true), + checkpoints = sol.t[1:500:end]) +_, easy_res62 = adjoint_sensitivities(sol_nodense, Tsit5(), t = t, dg_discrete = dg, + abstol = 1e-14, + reltol = 1e-14, + sensealg = InterpolatingAdjoint(checkpointing = true, + autojacvec = false), + checkpoints = sol.t[1:500:end]) # It should automatically be checkpointing since the solution isn't dense -_,easy_res7 = adjoint_sensitivities(sol_nodense,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, - reltol=1e-14, - sensealg=InterpolatingAdjoint(), - checkpoints=sol.t[1:500:end]) - -_,easy_res8 = adjoint_sensitivities(solb,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, - reltol=1e-14, - sensealg=InterpolatingAdjoint(autojacvec=SciMLSensitivity.TrackerVJP())) -_,easy_res9 = adjoint_sensitivities(solb,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, - reltol=1e-14, - sensealg=InterpolatingAdjoint(autojacvec=SciMLSensitivity.ZygoteVJP())) -_,easy_res10 = adjoint_sensitivities(solb,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, - reltol=1e-14, - sensealg=InterpolatingAdjoint(autojacvec=SciMLSensitivity.ReverseDiffVJP()) - ) -_,easy_res11 = adjoint_sensitivities(solb,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, - reltol=1e-14, - sensealg=InterpolatingAdjoint(autojacvec=SciMLSensitivity.ReverseDiffVJP(true)) - ) -_,easy_res12 = adjoint_sensitivities(solb,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, - reltol=1e-14, - sensealg=InterpolatingAdjoint(autojacvec=SciMLSensitivity.EnzymeVJP()) - ) -_,easy_res13 = adjoint_sensitivities(solb,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, - reltol=1e-14, - sensealg=QuadratureAdjoint(autojacvec=SciMLSensitivity.EnzymeVJP()) - ) - -adj_prob = ODEAdjointProblem(sol,QuadratureAdjoint(abstol=1e-14,reltol=1e-14,autojacvec=SciMLSensitivity.ReverseDiffVJP()),t,dg) -adj_sol = solve(adj_prob,Tsit5(),abstol=1e-14,reltol=1e-14) -integrand = AdjointSensitivityIntegrand(sol,adj_sol,QuadratureAdjoint(abstol=1e-14,reltol=1e-14,autojacvec=SciMLSensitivity.ReverseDiffVJP())) -res,err = quadgk(integrand,0.0,10.0,atol=1e-14,rtol=1e-12) +_, easy_res7 = adjoint_sensitivities(sol_nodense, Tsit5(), t = t, dg_discrete = dg, + abstol = 1e-14, + reltol = 1e-14, + sensealg = InterpolatingAdjoint(), + checkpoints = sol.t[1:500:end]) + +_, easy_res8 = adjoint_sensitivities(solb, Tsit5(), t = t, dg_discrete = dg, abstol = 1e-14, + reltol = 1e-14, + sensealg = InterpolatingAdjoint(autojacvec = SciMLSensitivity.TrackerVJP())) +_, easy_res9 = adjoint_sensitivities(solb, Tsit5(), t = t, dg_discrete = dg, abstol = 1e-14, + reltol = 1e-14, + sensealg = InterpolatingAdjoint(autojacvec = SciMLSensitivity.ZygoteVJP())) +_, easy_res10 = adjoint_sensitivities(solb, Tsit5(), t = t, dg_discrete = dg, + abstol = 1e-14, + reltol = 1e-14, + sensealg = InterpolatingAdjoint(autojacvec = SciMLSensitivity.ReverseDiffVJP())) +_, easy_res11 = adjoint_sensitivities(solb, Tsit5(), t = t, dg_discrete = dg, + abstol = 1e-14, + reltol = 1e-14, + sensealg = InterpolatingAdjoint(autojacvec = SciMLSensitivity.ReverseDiffVJP(true))) +_, easy_res12 = adjoint_sensitivities(solb, Tsit5(), t = t, dg_discrete = dg, + abstol = 1e-14, + reltol = 1e-14, + sensealg = InterpolatingAdjoint(autojacvec = SciMLSensitivity.EnzymeVJP())) +_, easy_res13 = adjoint_sensitivities(solb, Tsit5(), t = t, dg_discrete = dg, + abstol = 1e-14, + reltol = 1e-14, + sensealg = QuadratureAdjoint(autojacvec = SciMLSensitivity.EnzymeVJP())) + +adj_prob = ODEAdjointProblem(sol, + QuadratureAdjoint(abstol = 1e-14, reltol = 1e-14, + autojacvec = SciMLSensitivity.ReverseDiffVJP()), + t, dg) +adj_sol = solve(adj_prob, Tsit5(), abstol = 1e-14, reltol = 1e-14) +integrand = AdjointSensitivityIntegrand(sol, adj_sol, + QuadratureAdjoint(abstol = 1e-14, reltol = 1e-14, + autojacvec = SciMLSensitivity.ReverseDiffVJP())) +res, err = quadgk(integrand, 0.0, 10.0, atol = 1e-14, rtol = 1e-12) @test isapprox(res, easy_res, rtol = 1e-10) @test isapprox(res, easy_res2, rtol = 1e-10) @@ -137,56 +159,78 @@ res,err = quadgk(integrand,0.0,10.0,atol=1e-14,rtol=1e-12) println("OOP adjoint sensitivities ") -_,easy_res = adjoint_sensitivities(soloop,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, - reltol=1e-14) -_,easy_res2 = adjoint_sensitivities(soloop,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, - reltol=1e-14, - sensealg=QuadratureAdjoint(abstol=1e-14,reltol=1e-14)) -@test_broken easy_res22 = adjoint_sensitivities(soloop,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, - reltol=1e-14, - sensealg=QuadratureAdjoint(autojacvec=false,abstol=1e-14,reltol=1e-14))[1] isa AbstractArray -_,easy_res2 = adjoint_sensitivities(soloop,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, - reltol=1e-14, - sensealg=QuadratureAdjoint(abstol=1e-14,reltol=1e-14,autojacvec=ReverseDiffVJP(true))) -_,easy_res3 = adjoint_sensitivities(soloop,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, - reltol=1e-14, - sensealg=InterpolatingAdjoint()) -@test easy_res32 = adjoint_sensitivities(soloop,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, - reltol=1e-14, - sensealg=InterpolatingAdjoint(autojacvec=false))[1] isa AbstractArray -_,easy_res4 = adjoint_sensitivities(soloop,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, - reltol=1e-14, - sensealg=BacksolveAdjoint()) -@test easy_res42 = adjoint_sensitivities(soloop,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, - reltol=1e-14, - sensealg=BacksolveAdjoint(autojacvec=false))[1] isa AbstractArray -_,easy_res5 = adjoint_sensitivities(soloop,Kvaerno5(nlsolve=NLAnderson(), smooth_est=false), - t=t,dg_discrete=dg,abstol=1e-12, - reltol=1e-10, - sensealg=BacksolveAdjoint()) -_,easy_res6 = adjoint_sensitivities(soloop_nodense,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, - reltol=1e-14, - sensealg=InterpolatingAdjoint(checkpointing=true), - checkpoints=soloop_nodense.t[1:5:end]) -@test_broken easy_res62 = adjoint_sensitivities(soloop_nodense,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, - reltol=1e-14, - sensealg=InterpolatingAdjoint(checkpointing=true,autojacvec=false), - checkpoints=soloop_nodense.t[1:5:end]) - -_,easy_res8 = adjoint_sensitivities(soloop_nodense,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, - reltol=1e-14, - sensealg=InterpolatingAdjoint(autojacvec=SciMLSensitivity.TrackerVJP())) -_,easy_res9 = adjoint_sensitivities(soloop_nodense,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, - reltol=1e-14, - sensealg=InterpolatingAdjoint(autojacvec=SciMLSensitivity.ZygoteVJP())) -_,easy_res10 = adjoint_sensitivities(soloop_nodense,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, - reltol=1e-14, - sensealg=InterpolatingAdjoint(autojacvec=SciMLSensitivity.ReverseDiffVJP()) - ) -_,easy_res11 = adjoint_sensitivities(soloop_nodense,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, - reltol=1e-14, - sensealg=InterpolatingAdjoint(autojacvec=SciMLSensitivity.ReverseDiffVJP(true)) - ) +_, easy_res = adjoint_sensitivities(soloop, Tsit5(), t = t, dg_discrete = dg, + abstol = 1e-14, + reltol = 1e-14) +_, easy_res2 = adjoint_sensitivities(soloop, Tsit5(), t = t, dg_discrete = dg, + abstol = 1e-14, + reltol = 1e-14, + sensealg = QuadratureAdjoint(abstol = 1e-14, + reltol = 1e-14)) +@test_broken easy_res22 = adjoint_sensitivities(soloop, Tsit5(), t = t, dg_discrete = dg, + abstol = 1e-14, + reltol = 1e-14, + sensealg = QuadratureAdjoint(autojacvec = false, + abstol = 1e-14, + reltol = 1e-14))[1] isa + AbstractArray +_, easy_res2 = adjoint_sensitivities(soloop, Tsit5(), t = t, dg_discrete = dg, + abstol = 1e-14, + reltol = 1e-14, + sensealg = QuadratureAdjoint(abstol = 1e-14, + reltol = 1e-14, + autojacvec = ReverseDiffVJP(true))) +_, easy_res3 = adjoint_sensitivities(soloop, Tsit5(), t = t, dg_discrete = dg, + abstol = 1e-14, + reltol = 1e-14, + sensealg = InterpolatingAdjoint()) +@test easy_res32 = adjoint_sensitivities(soloop, Tsit5(), t = t, dg_discrete = dg, + abstol = 1e-14, + reltol = 1e-14, + sensealg = InterpolatingAdjoint(autojacvec = false))[1] isa + AbstractArray +_, easy_res4 = adjoint_sensitivities(soloop, Tsit5(), t = t, dg_discrete = dg, + abstol = 1e-14, + reltol = 1e-14, + sensealg = BacksolveAdjoint()) +@test easy_res42 = adjoint_sensitivities(soloop, Tsit5(), t = t, dg_discrete = dg, + abstol = 1e-14, + reltol = 1e-14, + sensealg = BacksolveAdjoint(autojacvec = false))[1] isa + AbstractArray +_, easy_res5 = adjoint_sensitivities(soloop, + Kvaerno5(nlsolve = NLAnderson(), smooth_est = false), + t = t, dg_discrete = dg, abstol = 1e-12, + reltol = 1e-10, + sensealg = BacksolveAdjoint()) +_, easy_res6 = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t, dg_discrete = dg, + abstol = 1e-14, + reltol = 1e-14, + sensealg = InterpolatingAdjoint(checkpointing = true), + checkpoints = soloop_nodense.t[1:5:end]) +@test_broken easy_res62 = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t, + dg_discrete = dg, abstol = 1e-14, + reltol = 1e-14, + sensealg = InterpolatingAdjoint(checkpointing = true, + autojacvec = false), + checkpoints = soloop_nodense.t[1:5:end]) + +_, easy_res8 = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t, dg_discrete = dg, + abstol = 1e-14, + reltol = 1e-14, + sensealg = InterpolatingAdjoint(autojacvec = SciMLSensitivity.TrackerVJP())) +_, easy_res9 = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t, dg_discrete = dg, + abstol = 1e-14, + reltol = 1e-14, + sensealg = InterpolatingAdjoint(autojacvec = SciMLSensitivity.ZygoteVJP())) +_, easy_res10 = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t, dg_discrete = dg, + abstol = 1e-14, + reltol = 1e-14, + sensealg = InterpolatingAdjoint(autojacvec = SciMLSensitivity.ReverseDiffVJP())) +_, easy_res11 = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t, dg_discrete = dg, + abstol = 1e-14, + reltol = 1e-14, + sensealg = InterpolatingAdjoint(autojacvec = SciMLSensitivity.ReverseDiffVJP(true))) #@test_broken _,easy_res12 = adjoint_sensitivities(soloop_nodense,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, # reltol=1e-14, # sensealg=InterpolatingAdjoint(autojacvec=SciMLSensitivity.EnzymeVJP()) @@ -216,35 +260,38 @@ _,easy_res11 = adjoint_sensitivities(soloop_nodense,Tsit5(),t=t,dg_discrete=dg,a println("Calculate adjoint sensitivities ") -_,easy_res8 = adjoint_sensitivities(solb,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, - reltol=1e-14, - save_everystep=false,save_start=false, - sensealg=BacksolveAdjoint()) -_,easy_res82 = adjoint_sensitivities(solb,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, - reltol=1e-14, - save_everystep=false,save_start=false, - sensealg=BacksolveAdjoint(checkpointing=false)) +_, easy_res8 = adjoint_sensitivities(solb, Tsit5(), t = t, dg_discrete = dg, abstol = 1e-14, + reltol = 1e-14, + save_everystep = false, save_start = false, + sensealg = BacksolveAdjoint()) +_, easy_res82 = adjoint_sensitivities(solb, Tsit5(), t = t, dg_discrete = dg, + abstol = 1e-14, + reltol = 1e-14, + save_everystep = false, save_start = false, + sensealg = BacksolveAdjoint(checkpointing = false)) @test isapprox(res, easy_res8, rtol = 1e-9) @test isapprox(res, easy_res82, rtol = 1e-9) -_,end_only_res = adjoint_sensitivities(sol_end,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, - reltol=1e-14, - save_everystep=false,save_start=false, - sensealg=BacksolveAdjoint()) +_, end_only_res = adjoint_sensitivities(sol_end, Tsit5(), t = t, dg_discrete = dg, + abstol = 1e-14, + reltol = 1e-14, + save_everystep = false, save_start = false, + sensealg = BacksolveAdjoint()) @test isapprox(res, end_only_res, rtol = 1e-9) println("Calculate adjoint sensitivities from autodiff & numerical diff") function G(p) - tmp_prob = remake(prob,u0=convert.(eltype(p),prob.u0),p=p) - sol = solve(tmp_prob,Tsit5(),abstol=1e-14,reltol=1e-14,sensealg=DiffEqBase.SensitivityADPassThrough(),saveat=t) - A = Array(sol) - sum(((2 .- A).^2)./2) + tmp_prob = remake(prob, u0 = convert.(eltype(p), prob.u0), p = p) + sol = solve(tmp_prob, Tsit5(), abstol = 1e-14, reltol = 1e-14, + sensealg = DiffEqBase.SensitivityADPassThrough(), saveat = t) + A = Array(sol) + sum(((2 .- A) .^ 2) ./ 2) end -G([1.5,1.0,3.0,1.0]) -res2 = ForwardDiff.gradient(G,[1.5,1.0,3.0,1.0]) -res3 = Calculus.gradient(G,[1.5,1.0,3.0,1.0]) +G([1.5, 1.0, 3.0, 1.0]) +res2 = ForwardDiff.gradient(G, [1.5, 1.0, 3.0, 1.0]) +res3 = Calculus.gradient(G, [1.5, 1.0, 3.0, 1.0]) @test norm(res' .- res2) < 1e-7 @test norm(res' .- res3) < 1e-5 @@ -255,22 +302,23 @@ t2 = [0.5, 1.0] t3 = [0.0, 0.5, 1.0] t4 = [0.5, 1.0, 10.0] -_,easy_res2 = adjoint_sensitivities(sol,Tsit5(),t=t2,dg_discrete=dg,abstol=1e-14, - reltol=1e-14) -_,easy_res3 = adjoint_sensitivities(sol,Tsit5(),t=t3,dg_discrete=dg,abstol=1e-14, - reltol=1e-14) -_,easy_res4 = adjoint_sensitivities(sol,Tsit5(),t=t4,dg_discrete=dg,abstol=1e-14, - reltol=1e-14) - -function G(p,ts) - tmp_prob = remake(prob,u0=convert.(eltype(p),prob.u0),p=p) - sol = solve(tmp_prob,Tsit5(),abstol=1e-10,reltol=1e-10,sensealg=DiffEqBase.SensitivityADPassThrough(),saveat=ts) - A = convert(Array,sol) - sum(((2 .- A).^2)./2) +_, easy_res2 = adjoint_sensitivities(sol, Tsit5(), t = t2, dg_discrete = dg, abstol = 1e-14, + reltol = 1e-14) +_, easy_res3 = adjoint_sensitivities(sol, Tsit5(), t = t3, dg_discrete = dg, abstol = 1e-14, + reltol = 1e-14) +_, easy_res4 = adjoint_sensitivities(sol, Tsit5(), t = t4, dg_discrete = dg, abstol = 1e-14, + reltol = 1e-14) + +function G(p, ts) + tmp_prob = remake(prob, u0 = convert.(eltype(p), prob.u0), p = p) + sol = solve(tmp_prob, Tsit5(), abstol = 1e-10, reltol = 1e-10, + sensealg = DiffEqBase.SensitivityADPassThrough(), saveat = ts) + A = convert(Array, sol) + sum(((2 .- A) .^ 2) ./ 2) end -res2 = ForwardDiff.gradient(p->G(p,t2),[1.5,1.0,3.0,1.0]) -res3 = ForwardDiff.gradient(p->G(p,t3),[1.5,1.0,3.0,1.0]) -res4 = ForwardDiff.gradient(p->G(p,t4),[1.5,1.0,3.0,1.0]) +res2 = ForwardDiff.gradient(p -> G(p, t2), [1.5, 1.0, 3.0, 1.0]) +res3 = ForwardDiff.gradient(p -> G(p, t3), [1.5, 1.0, 3.0, 1.0]) +res4 = ForwardDiff.gradient(p -> G(p, t4), [1.5, 1.0, 3.0, 1.0]) @test easy_res2' ≈ res2 @test easy_res3' ≈ res3 @@ -278,191 +326,226 @@ res4 = ForwardDiff.gradient(p->G(p,t4),[1.5,1.0,3.0,1.0]) println("Adjoints of u0") -function dg(out,u,p,t,i) - out .= -1 .+ u +function dg(out, u, p, t, i) + out .= -1 .+ u end -ū0,adj = adjoint_sensitivities(sol,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, - reltol=1e-14) - -_,adjnou0 = adjoint_sensitivities(sol,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, - reltol=1e-14) - -ū02,adj2 = adjoint_sensitivities(sol,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, - sensealg=BacksolveAdjoint(), - reltol=1e-14) - -ū022,adj22 = adjoint_sensitivities(sol,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, - sensealg=BacksolveAdjoint(autojacvec=false), - reltol=1e-14) - -ū023,adj23 = adjoint_sensitivities(sol,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, - sensealg=BacksolveAdjoint(autojacvec=false,checkpointing=false), - reltol=1e-14) - -ū03,adj3 = adjoint_sensitivities(sol,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, - sensealg=InterpolatingAdjoint(), - reltol=1e-14) - -ū032,adj32 = adjoint_sensitivities(sol,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, - sensealg=InterpolatingAdjoint(autojacvec=false), - reltol=1e-14) - -ū04,adj4 = adjoint_sensitivities(sol,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, - sensealg=InterpolatingAdjoint(checkpointing=true), - checkpoints=sol.t[1:500:end], - reltol=1e-14) - -@test_nowarn adjoint_sensitivities(sol,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, - sensealg=InterpolatingAdjoint(checkpointing=true), - checkpoints=sol.t[1:5:end], - reltol=1e-14) - -ū042,adj42 = adjoint_sensitivities(sol,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, - sensealg=InterpolatingAdjoint(checkpointing=true,autojacvec=false), - checkpoints=sol.t[1:500:end], - reltol=1e-14) - -ū05,adj5 = adjoint_sensitivities(sol,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, - sensealg=QuadratureAdjoint(abstol=1e-14,reltol=1e-14), - reltol=1e-14) - -ū052,adj52 = adjoint_sensitivities(sol,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, - sensealg=QuadratureAdjoint(autojacvec=false,abstol=1e-14,reltol=1e-14), - reltol=1e-14) - -ū05,adj53 = adjoint_sensitivities(sol,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, - sensealg=QuadratureAdjoint(abstol=1e-14,reltol=1e-14,autojacvec=ReverseDiffVJP(true)), - reltol=1e-14) - -ū0args,adjargs = adjoint_sensitivities(sol,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, - save_everystep=false, save_start=false, - sensealg=BacksolveAdjoint(), - reltol=1e-14) - -ū0args2,adjargs2 = adjoint_sensitivities(sol,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, - save_everystep=false, save_start=false, - sensealg=InterpolatingAdjoint(), - reltol=1e-14) +ū0, adj = adjoint_sensitivities(sol, Tsit5(), t = t, dg_discrete = dg, abstol = 1e-14, + reltol = 1e-14) + +_, adjnou0 = adjoint_sensitivities(sol, Tsit5(), t = t, dg_discrete = dg, abstol = 1e-14, + reltol = 1e-14) + +ū02, adj2 = adjoint_sensitivities(sol, Tsit5(), t = t, dg_discrete = dg, abstol = 1e-14, + sensealg = BacksolveAdjoint(), + reltol = 1e-14) + +ū022, adj22 = adjoint_sensitivities(sol, Tsit5(), t = t, dg_discrete = dg, abstol = 1e-14, + sensealg = BacksolveAdjoint(autojacvec = false), + reltol = 1e-14) + +ū023, adj23 = adjoint_sensitivities(sol, Tsit5(), t = t, dg_discrete = dg, abstol = 1e-14, + sensealg = BacksolveAdjoint(autojacvec = false, + checkpointing = false), + reltol = 1e-14) + +ū03, adj3 = adjoint_sensitivities(sol, Tsit5(), t = t, dg_discrete = dg, abstol = 1e-14, + sensealg = InterpolatingAdjoint(), + reltol = 1e-14) + +ū032, adj32 = adjoint_sensitivities(sol, Tsit5(), t = t, dg_discrete = dg, abstol = 1e-14, + sensealg = InterpolatingAdjoint(autojacvec = false), + reltol = 1e-14) + +ū04, adj4 = adjoint_sensitivities(sol, Tsit5(), t = t, dg_discrete = dg, abstol = 1e-14, + sensealg = InterpolatingAdjoint(checkpointing = true), + checkpoints = sol.t[1:500:end], + reltol = 1e-14) + +@test_nowarn adjoint_sensitivities(sol, Tsit5(), t = t, dg_discrete = dg, abstol = 1e-14, + sensealg = InterpolatingAdjoint(checkpointing = true), + checkpoints = sol.t[1:5:end], + reltol = 1e-14) + +ū042, adj42 = adjoint_sensitivities(sol, Tsit5(), t = t, dg_discrete = dg, abstol = 1e-14, + sensealg = InterpolatingAdjoint(checkpointing = true, + autojacvec = false), + checkpoints = sol.t[1:500:end], + reltol = 1e-14) + +ū05, adj5 = adjoint_sensitivities(sol, Tsit5(), t = t, dg_discrete = dg, abstol = 1e-14, + sensealg = QuadratureAdjoint(abstol = 1e-14, + reltol = 1e-14), + reltol = 1e-14) + +ū052, adj52 = adjoint_sensitivities(sol, Tsit5(), t = t, dg_discrete = dg, abstol = 1e-14, + sensealg = QuadratureAdjoint(autojacvec = false, + abstol = 1e-14, + reltol = 1e-14), + reltol = 1e-14) + +ū05, adj53 = adjoint_sensitivities(sol, Tsit5(), t = t, dg_discrete = dg, abstol = 1e-14, + sensealg = QuadratureAdjoint(abstol = 1e-14, + reltol = 1e-14, + autojacvec = ReverseDiffVJP(true)), + reltol = 1e-14) + +ū0args, adjargs = adjoint_sensitivities(sol, Tsit5(), t = t, dg_discrete = dg, + abstol = 1e-14, + save_everystep = false, save_start = false, + sensealg = BacksolveAdjoint(), + reltol = 1e-14) + +ū0args2, adjargs2 = adjoint_sensitivities(sol, Tsit5(), t = t, dg_discrete = dg, + abstol = 1e-14, + save_everystep = false, save_start = false, + sensealg = InterpolatingAdjoint(), + reltol = 1e-14) res = ForwardDiff.gradient(prob.u0) do u0 - tmp_prob = remake(prob,u0=u0) - sol = solve(tmp_prob,Tsit5(),abstol=1e-14,reltol=1e-14,saveat=t) - A = convert(Array,sol) - sum(((1 .- A).^2)./2) + tmp_prob = remake(prob, u0 = u0) + sol = solve(tmp_prob, Tsit5(), abstol = 1e-14, reltol = 1e-14, saveat = t) + A = convert(Array, sol) + sum(((1 .- A) .^ 2) ./ 2) end -@test ū0 ≈ res rtol = 1e-10 -@test ū02 ≈ res rtol = 1e-10 -@test ū022 ≈ res rtol = 1e-10 -@test ū023 ≈ res rtol = 1e-10 -@test ū03 ≈ res rtol = 1e-10 -@test ū032 ≈ res rtol = 1e-10 -@test ū04 ≈ res rtol = 1e-10 -@test ū042 ≈ res rtol = 1e-10 -@test ū05 ≈ res rtol = 1e-10 -@test ū052 ≈ res rtol = 1e-10 -@test adj ≈ adjnou0 rtol = 1e-10 -@test adj ≈ adj2 rtol = 1e-10 -@test adj ≈ adj22 rtol = 1e-10 -@test adj ≈ adj23 rtol = 1e-10 -@test adj ≈ adj3 rtol = 1e-10 -@test adj ≈ adj32 rtol = 1e-10 -@test adj ≈ adj4 rtol = 1e-10 -@test adj ≈ adj42 rtol = 1e-10 -@test adj ≈ adj5 rtol = 1e-10 -@test adj ≈ adj52 rtol = 1e-10 -@test adj ≈ adj53 rtol = 1e-10 - -@test ū0args ≈ res rtol = 1e-10 -@test adjargs ≈ adj rtol = 1e-10 -@test ū0args2 ≈ res rtol = 1e-10 -@test adjargs2 ≈ adj rtol = 1e-10 +@test ū0≈res rtol=1e-10 +@test ū02≈res rtol=1e-10 +@test ū022≈res rtol=1e-10 +@test ū023≈res rtol=1e-10 +@test ū03≈res rtol=1e-10 +@test ū032≈res rtol=1e-10 +@test ū04≈res rtol=1e-10 +@test ū042≈res rtol=1e-10 +@test ū05≈res rtol=1e-10 +@test ū052≈res rtol=1e-10 +@test adj≈adjnou0 rtol=1e-10 +@test adj≈adj2 rtol=1e-10 +@test adj≈adj22 rtol=1e-10 +@test adj≈adj23 rtol=1e-10 +@test adj≈adj3 rtol=1e-10 +@test adj≈adj32 rtol=1e-10 +@test adj≈adj4 rtol=1e-10 +@test adj≈adj42 rtol=1e-10 +@test adj≈adj5 rtol=1e-10 +@test adj≈adj52 rtol=1e-10 +@test adj≈adj53 rtol=1e-10 + +@test ū0args≈res rtol=1e-10 +@test adjargs≈adj rtol=1e-10 +@test ū0args2≈res rtol=1e-10 +@test adjargs2≈adj rtol=1e-10 println("Do a continuous adjoint problem") # Energy calculation -g(u,p,t) = (sum(u).^2) ./ 2 +g(u, p, t) = (sum(u) .^ 2) ./ 2 # Gradient of (u1 + u2)^2 / 2 -function dg(out,u,p,t) - out[1]= u[1] + u[2] - out[2]= u[1] + u[2] +function dg(out, u, p, t) + out[1] = u[1] + u[2] + out[2] = u[1] + u[2] end -adj_prob = ODEAdjointProblem(sol,QuadratureAdjoint(abstol=1e-14,reltol=1e-14,autojacvec=SciMLSensitivity.ReverseDiffVJP()),nothing,nothing,dg,g) -adj_sol = solve(adj_prob,Tsit5(),abstol=1e-14,reltol=1e-10) -integrand = AdjointSensitivityIntegrand(sol,adj_sol,QuadratureAdjoint(abstol=1e-14,reltol=1e-14,autojacvec=SciMLSensitivity.ReverseDiffVJP())) -res,err = quadgk(integrand,0.0,10.0,atol=1e-14,rtol=1e-10) +adj_prob = ODEAdjointProblem(sol, + QuadratureAdjoint(abstol = 1e-14, reltol = 1e-14, + autojacvec = SciMLSensitivity.ReverseDiffVJP()), + nothing, nothing, dg, g) +adj_sol = solve(adj_prob, Tsit5(), abstol = 1e-14, reltol = 1e-10) +integrand = AdjointSensitivityIntegrand(sol, adj_sol, + QuadratureAdjoint(abstol = 1e-14, reltol = 1e-14, + autojacvec = SciMLSensitivity.ReverseDiffVJP())) +res, err = quadgk(integrand, 0.0, 10.0, atol = 1e-14, rtol = 1e-10) println("Test the `adjoint_sensitivities` utility function") -_,easy_res = adjoint_sensitivities(sol,Tsit5(),dg_continuous=dg,g=g,abstol=1e-14, - reltol=1e-14) +_, easy_res = adjoint_sensitivities(sol, Tsit5(), dg_continuous = dg, g = g, abstol = 1e-14, + reltol = 1e-14) println("2") -_,easy_res2 = adjoint_sensitivities(sol,Tsit5(),dg_continuous=dg,g=g,abstol=1e-14, - reltol=1e-14, - sensealg=InterpolatingAdjoint()) -_,easy_res22 = adjoint_sensitivities(sol,Tsit5(),dg_continuous=dg,g=g,abstol=1e-14, - reltol=1e-14, - sensealg=InterpolatingAdjoint(autojacvec=false)) +_, easy_res2 = adjoint_sensitivities(sol, Tsit5(), dg_continuous = dg, g = g, + abstol = 1e-14, + reltol = 1e-14, + sensealg = InterpolatingAdjoint()) +_, easy_res22 = adjoint_sensitivities(sol, Tsit5(), dg_continuous = dg, g = g, + abstol = 1e-14, + reltol = 1e-14, + sensealg = InterpolatingAdjoint(autojacvec = false)) println("23") -_,easy_res23 = adjoint_sensitivities(sol,Tsit5(),dg_continuous=dg,g=g,abstol=1e-14, - reltol=1e-14, - sensealg=QuadratureAdjoint(abstol=1e-14,reltol=1e-14)) -_,easy_res232 = adjoint_sensitivities(sol,Tsit5(),dg_continuous=dg,g=g,abstol=1e-14, - reltol=1e-14, - sensealg=QuadratureAdjoint(abstol=1e-14,reltol=1e-14,autojacvec=ReverseDiffVJP(false))) -_,easy_res24 = adjoint_sensitivities(sol,Tsit5(),dg_continuous=dg,g=g,abstol=1e-14, - reltol=1e-14, - sensealg=QuadratureAdjoint(autojacvec=false,abstol=1e-14,reltol=1e-14)) +_, easy_res23 = adjoint_sensitivities(sol, Tsit5(), dg_continuous = dg, g = g, + abstol = 1e-14, + reltol = 1e-14, + sensealg = QuadratureAdjoint(abstol = 1e-14, + reltol = 1e-14)) +_, easy_res232 = adjoint_sensitivities(sol, Tsit5(), dg_continuous = dg, g = g, + abstol = 1e-14, + reltol = 1e-14, + sensealg = QuadratureAdjoint(abstol = 1e-14, + reltol = 1e-14, + autojacvec = ReverseDiffVJP(false))) +_, easy_res24 = adjoint_sensitivities(sol, Tsit5(), dg_continuous = dg, g = g, + abstol = 1e-14, + reltol = 1e-14, + sensealg = QuadratureAdjoint(autojacvec = false, + abstol = 1e-14, + reltol = 1e-14)) println("25") -_,easy_res25 = adjoint_sensitivities(sol,Tsit5(),dg_continuous=dg,g=g,abstol=1e-14, - reltol=1e-14, - sensealg=BacksolveAdjoint()) -_,easy_res26 = adjoint_sensitivities(sol,Tsit5(),dg_continuous=dg,g=g,abstol=1e-14, - reltol=1e-14, - sensealg=BacksolveAdjoint(autojacvec=false)) -_,easy_res262 = adjoint_sensitivities(sol,Tsit5(),dg_continuous=dg,g=g,abstol=1e-14, - reltol=1e-14, - sensealg=BacksolveAdjoint(autojacvec=false,checkpointing=false)) +_, easy_res25 = adjoint_sensitivities(sol, Tsit5(), dg_continuous = dg, g = g, + abstol = 1e-14, + reltol = 1e-14, + sensealg = BacksolveAdjoint()) +_, easy_res26 = adjoint_sensitivities(sol, Tsit5(), dg_continuous = dg, g = g, + abstol = 1e-14, + reltol = 1e-14, + sensealg = BacksolveAdjoint(autojacvec = false)) +_, easy_res262 = adjoint_sensitivities(sol, Tsit5(), dg_continuous = dg, g = g, + abstol = 1e-14, + reltol = 1e-14, + sensealg = BacksolveAdjoint(autojacvec = false, + checkpointing = false)) println("27") -_,easy_res27 = adjoint_sensitivities(sol,Tsit5(),dg_continuous=dg,g=g,abstol=1e-14, - reltol=1e-14, - checkpoints=sol.t[1:500:end], - sensealg=InterpolatingAdjoint(checkpointing=true)) -_,easy_res28 = adjoint_sensitivities(sol,Tsit5(),dg_continuous=dg,g=g,abstol=1e-14, - reltol=1e-14, - checkpoints=sol.t[1:500:end], - sensealg=InterpolatingAdjoint(checkpointing=true,autojacvec=false)) +_, easy_res27 = adjoint_sensitivities(sol, Tsit5(), dg_continuous = dg, g = g, + abstol = 1e-14, + reltol = 1e-14, + checkpoints = sol.t[1:500:end], + sensealg = InterpolatingAdjoint(checkpointing = true)) +_, easy_res28 = adjoint_sensitivities(sol, Tsit5(), dg_continuous = dg, g = g, + abstol = 1e-14, + reltol = 1e-14, + checkpoints = sol.t[1:500:end], + sensealg = InterpolatingAdjoint(checkpointing = true, + autojacvec = false)) println("3") -_,easy_res3 = adjoint_sensitivities(sol,Tsit5(),g=g,abstol=1e-14, - reltol=1e-14, - sensealg=InterpolatingAdjoint()) -_,easy_res32 = adjoint_sensitivities(sol,Tsit5(),g=g,abstol=1e-14, - reltol=1e-14, - sensealg=InterpolatingAdjoint(autojacvec=false)) +_, easy_res3 = adjoint_sensitivities(sol, Tsit5(), g = g, abstol = 1e-14, + reltol = 1e-14, + sensealg = InterpolatingAdjoint()) +_, easy_res32 = adjoint_sensitivities(sol, Tsit5(), g = g, abstol = 1e-14, + reltol = 1e-14, + sensealg = InterpolatingAdjoint(autojacvec = false)) println("33") -_,easy_res33 = adjoint_sensitivities(sol,Tsit5(),g=g,abstol=1e-14, - reltol=1e-14, - sensealg=QuadratureAdjoint(abstol=1e-14,reltol=1e-14)) -_,easy_res34 = adjoint_sensitivities(sol,Tsit5(),g=g,abstol=1e-14, - reltol=1e-14, - sensealg=QuadratureAdjoint(autojacvec=false,abstol=1e-14,reltol=1e-14)) +_, easy_res33 = adjoint_sensitivities(sol, Tsit5(), g = g, abstol = 1e-14, + reltol = 1e-14, + sensealg = QuadratureAdjoint(abstol = 1e-14, + reltol = 1e-14)) +_, easy_res34 = adjoint_sensitivities(sol, Tsit5(), g = g, abstol = 1e-14, + reltol = 1e-14, + sensealg = QuadratureAdjoint(autojacvec = false, + abstol = 1e-14, + reltol = 1e-14)) println("35") -_,easy_res35 = adjoint_sensitivities(sol,Tsit5(),g=g,abstol=1e-14, - reltol=1e-14, - sensealg=BacksolveAdjoint()) -_,easy_res36 = adjoint_sensitivities(sol,Tsit5(),g=g,abstol=1e-14, - reltol=1e-14, - sensealg=BacksolveAdjoint(autojacvec=false)) +_, easy_res35 = adjoint_sensitivities(sol, Tsit5(), g = g, abstol = 1e-14, + reltol = 1e-14, + sensealg = BacksolveAdjoint()) +_, easy_res36 = adjoint_sensitivities(sol, Tsit5(), g = g, abstol = 1e-14, + reltol = 1e-14, + sensealg = BacksolveAdjoint(autojacvec = false)) println("37") -_,easy_res37 = adjoint_sensitivities(sol,Tsit5(),g=g,abstol=1e-14, - reltol=1e-14, - checkpoints=sol.t[1:500:end], - sensealg=InterpolatingAdjoint(checkpointing=true)) -_,easy_res38 = adjoint_sensitivities(sol,Tsit5(),g=g,abstol=1e-14, - reltol=1e-14, - checkpoints=sol.t[1:500:end], - sensealg=InterpolatingAdjoint(checkpointing=true,autojacvec=false)) +_, easy_res37 = adjoint_sensitivities(sol, Tsit5(), g = g, abstol = 1e-14, + reltol = 1e-14, + checkpoints = sol.t[1:500:end], + sensealg = InterpolatingAdjoint(checkpointing = true)) +_, easy_res38 = adjoint_sensitivities(sol, Tsit5(), g = g, abstol = 1e-14, + reltol = 1e-14, + checkpoints = sol.t[1:500:end], + sensealg = InterpolatingAdjoint(checkpointing = true, + autojacvec = false)) @test norm(easy_res .- res) < 1e-8 @test norm(easy_res2 .- res) < 1e-8 @@ -486,241 +569,283 @@ _,easy_res38 = adjoint_sensitivities(sol,Tsit5(),g=g,abstol=1e-14, println("Calculate adjoint sensitivities from autodiff & numerical diff") function G(p) - tmp_prob = remake(prob,u0=eltype(p).(prob.u0),p=p, - tspan=eltype(p).(prob.tspan)) - sol = solve(tmp_prob,Tsit5(),abstol=1e-14,reltol=1e-14) - res,err = quadgk((t)-> (sum(sol(t)).^2)./2,0.0,10.0,atol=1e-14,rtol=1e-10) - res + tmp_prob = remake(prob, u0 = eltype(p).(prob.u0), p = p, + tspan = eltype(p).(prob.tspan)) + sol = solve(tmp_prob, Tsit5(), abstol = 1e-14, reltol = 1e-14) + res, err = quadgk((t) -> (sum(sol(t)) .^ 2) ./ 2, 0.0, 10.0, atol = 1e-14, rtol = 1e-10) + res end -res2 = ForwardDiff.gradient(G,[1.5,1.0,3.0,1.0]) -res3 = Calculus.gradient(G,[1.5,1.0,3.0,1.0]) +res2 = ForwardDiff.gradient(G, [1.5, 1.0, 3.0, 1.0]) +res3 = Calculus.gradient(G, [1.5, 1.0, 3.0, 1.0]) @test norm(res' .- res2) < 1e-8 @test norm(res' .- res3) < 1e-6 # Buffer length test f = (du, u, p, t) -> du .= 0 -p = zeros(3); u = zeros(50) -prob = ODEProblem(f,u,(0.0,10.0),p) -sol = solve(prob,Tsit5(),abstol=1e-14,reltol=1e-14) -@test_nowarn _,res = adjoint_sensitivities(sol,Tsit5(),t=t,dg_discrete=dg,abstol=1e-14, - reltol=1e-14) +p = zeros(3); +u = zeros(50); +prob = ODEProblem(f, u, (0.0, 10.0), p) +sol = solve(prob, Tsit5(), abstol = 1e-14, reltol = 1e-14) +@test_nowarn _, res = adjoint_sensitivities(sol, Tsit5(), t = t, dg_discrete = dg, + abstol = 1e-14, + reltol = 1e-14) @testset "Checkpointed backsolve" begin - using SciMLSensitivity, OrdinaryDiffEq - tf = 10.0 - function lorenz(du,u,p,t) - σ, ρ, β = p - du[1] = σ*(u[2]-u[1]) - du[2] = u[1]*(ρ-u[3]) - u[2] - du[3] = u[1]*u[2] - β*u[3] - return nothing - end - prob_lorenz = ODEProblem(lorenz, [1.0, 0.0, 0.0], (0, tf), [10, 28, 8/3]) - sol_lorenz = solve(prob_lorenz,Tsit5(),reltol=1e-6,abstol=1e-9) - function dg(out,u,p,t,i) - (out.=-2.0.+u) - end - t = 0:0.1:tf - _,easy_res1 = adjoint_sensitivities(sol_lorenz,Tsit5(),t=t,dg_discrete=dg,abstol=1e-6, - reltol=1e-9, - sensealg=BacksolveAdjoint()) - _,easy_res2 = adjoint_sensitivities(sol_lorenz,Tsit5(),t=t,dg_discrete=dg,abstol=1e-6, - reltol=1e-9, - sensealg=InterpolatingAdjoint()) - _,easy_res3 = adjoint_sensitivities(sol_lorenz,Tsit5(),t=t,dg_discrete=dg,abstol=1e-6, - reltol=1e-9, - sensealg=BacksolveAdjoint(), - checkpoints=sol_lorenz.t[1:10:end]) - _,easy_res4 = adjoint_sensitivities(sol_lorenz,Tsit5(),t=t,dg_discrete=dg,abstol=1e-6, - reltol=1e-9, - sensealg=BacksolveAdjoint(), - checkpoints=sol_lorenz.t[1:20:end]) - # cannot finish in a reasonable amount of time - @test_skip adjoint_sensitivities(sol_lorenz,Tsit5(),t=t,dg_discrete=dg,abstol=1e-6, - reltol=1e-9, - sensealg=BacksolveAdjoint(checkpointing=false)) - @test easy_res2 ≈ easy_res1 rtol=1e-5 - @test easy_res2 ≈ easy_res3 rtol=1e-5 - @test easy_res2 ≈ easy_res4 rtol=1e-4 - - ū1,adj1 = adjoint_sensitivities(sol_lorenz,Tsit5(),t=t,dg_discrete=dg,abstol=1e-6, - reltol=1e-9, - sensealg=BacksolveAdjoint()) - ū2,adj2 = adjoint_sensitivities(sol_lorenz,Tsit5(),t=t,dg_discrete=dg,abstol=1e-6, - reltol=1e-9, - sensealg=InterpolatingAdjoint()) - ū3,adj3 = adjoint_sensitivities(sol_lorenz,Tsit5(),t=t,dg_discrete=dg,abstol=1e-6, - reltol=1e-9, - sensealg=BacksolveAdjoint(), - checkpoints=sol_lorenz.t[1:10:end]) - ū4,adj4 = adjoint_sensitivities(sol_lorenz,Tsit5(),t=t,dg_discrete=dg,abstol=1e-6, - reltol=1e-9, - sensealg=BacksolveAdjoint(), - checkpoints=sol_lorenz.t[1:20:end]) - # cannot finish in a reasonable amount of time - @test_skip adjoint_sensitivities(sol_lorenz,Tsit5(),t=t,dg_discrete=dg,abstol=1e-6, - reltol=1e-9, - sensealg=BacksolveAdjoint(checkpointing=false)) - @test ū2 ≈ ū1 rtol=1e-5 - @test adj2 ≈ adj1 rtol=1e-5 - @test ū2 ≈ ū3 rtol=1e-5 - @test adj2 ≈ adj3 rtol=1e-5 - @test ū2 ≈ ū4 rtol=1e-4 - @test adj2 ≈ adj4 rtol=1e-4 - - - # LQR Tests from issue https://github.com/SciML/SciMLSensitivity.jl/issues/300 - x_dim = 2 - T = 40.0 - - cost = (x, u) -> x'*x - params = [-0.4142135623730951, 0.0, -0.0, -0.4142135623730951, 0.0, 0.0] - - function dynamics!(du,u,p,t) - du[1] = -u[1] + tanh(p[1]*u[1]+p[2]*u[2]) - du[2] = -u[2] + tanh(p[3]*u[1]+p[4]*u[2]) - end - - function backsolve_grad(sol, lqr_params, checkpointing) - bwd_sol = solve( - ODEAdjointProblem( - sol, - BacksolveAdjoint(autojacvec=EnzymeVJP(),checkpointing = checkpointing), - nothing, nothing, nothing, (x, lqr_params, t) -> cost(x,lqr_params) - ), - Tsit5(), - dense = false, - save_everystep = false, - ) - - bwd_sol.u[end][1:end-x_dim] - #fwd_sol, bwd_sol - end - - - - x0 = ones(x_dim) - fwd_sol = solve( - ODEProblem(dynamics!, x0, (0, T), params), - Tsit5(),abstol=1e-9, reltol=1e-9, - u0 = x0, - p = params, - dense = false, - save_everystep = true - ) - - - - backsolve_results = backsolve_grad(fwd_sol, params, false) - backsolve_checkpointing_results = backsolve_grad(fwd_sol, params, true) - - @test backsolve_results != backsolve_checkpointing_results - - int_u0, int_p = adjoint_sensitivities(fwd_sol,Tsit5(),g=(x, params, t)->cost(x,params), sensealg=InterpolatingAdjoint()) - - @test isapprox(backsolve_checkpointing_results[1:length(x0)], int_u0, rtol=1e-10) - @test isapprox(backsolve_checkpointing_results[(1:length(params)) .+ length(x0)], int_p', rtol=1e-10) + using SciMLSensitivity, OrdinaryDiffEq + tf = 10.0 + function lorenz(du, u, p, t) + σ, ρ, β = p + du[1] = σ * (u[2] - u[1]) + du[2] = u[1] * (ρ - u[3]) - u[2] + du[3] = u[1] * u[2] - β * u[3] + return nothing + end + prob_lorenz = ODEProblem(lorenz, [1.0, 0.0, 0.0], (0, tf), [10, 28, 8 / 3]) + sol_lorenz = solve(prob_lorenz, Tsit5(), reltol = 1e-6, abstol = 1e-9) + function dg(out, u, p, t, i) + (out .= -2.0 .+ u) + end + t = 0:0.1:tf + _, easy_res1 = adjoint_sensitivities(sol_lorenz, Tsit5(), t = t, dg_discrete = dg, + abstol = 1e-6, + reltol = 1e-9, + sensealg = BacksolveAdjoint()) + _, easy_res2 = adjoint_sensitivities(sol_lorenz, Tsit5(), t = t, dg_discrete = dg, + abstol = 1e-6, + reltol = 1e-9, + sensealg = InterpolatingAdjoint()) + _, easy_res3 = adjoint_sensitivities(sol_lorenz, Tsit5(), t = t, dg_discrete = dg, + abstol = 1e-6, + reltol = 1e-9, + sensealg = BacksolveAdjoint(), + checkpoints = sol_lorenz.t[1:10:end]) + _, easy_res4 = adjoint_sensitivities(sol_lorenz, Tsit5(), t = t, dg_discrete = dg, + abstol = 1e-6, + reltol = 1e-9, + sensealg = BacksolveAdjoint(), + checkpoints = sol_lorenz.t[1:20:end]) + # cannot finish in a reasonable amount of time + @test_skip adjoint_sensitivities(sol_lorenz, Tsit5(), t = t, dg_discrete = dg, + abstol = 1e-6, + reltol = 1e-9, + sensealg = BacksolveAdjoint(checkpointing = false)) + @test easy_res2≈easy_res1 rtol=1e-5 + @test easy_res2≈easy_res3 rtol=1e-5 + @test easy_res2≈easy_res4 rtol=1e-4 + + ū1, adj1 = adjoint_sensitivities(sol_lorenz, Tsit5(), t = t, dg_discrete = dg, + abstol = 1e-6, + reltol = 1e-9, + sensealg = BacksolveAdjoint()) + ū2, adj2 = adjoint_sensitivities(sol_lorenz, Tsit5(), t = t, dg_discrete = dg, + abstol = 1e-6, + reltol = 1e-9, + sensealg = InterpolatingAdjoint()) + ū3, adj3 = adjoint_sensitivities(sol_lorenz, Tsit5(), t = t, dg_discrete = dg, + abstol = 1e-6, + reltol = 1e-9, + sensealg = BacksolveAdjoint(), + checkpoints = sol_lorenz.t[1:10:end]) + ū4, adj4 = adjoint_sensitivities(sol_lorenz, Tsit5(), t = t, dg_discrete = dg, + abstol = 1e-6, + reltol = 1e-9, + sensealg = BacksolveAdjoint(), + checkpoints = sol_lorenz.t[1:20:end]) + # cannot finish in a reasonable amount of time + @test_skip adjoint_sensitivities(sol_lorenz, Tsit5(), t = t, dg_discrete = dg, + abstol = 1e-6, + reltol = 1e-9, + sensealg = BacksolveAdjoint(checkpointing = false)) + @test ū2≈ū1 rtol=1e-5 + @test adj2≈adj1 rtol=1e-5 + @test ū2≈ū3 rtol=1e-5 + @test adj2≈adj3 rtol=1e-5 + @test ū2≈ū4 rtol=1e-4 + @test adj2≈adj4 rtol=1e-4 + + # LQR Tests from issue https://github.com/SciML/SciMLSensitivity.jl/issues/300 + x_dim = 2 + T = 40.0 + + cost = (x, u) -> x' * x + params = [-0.4142135623730951, 0.0, -0.0, -0.4142135623730951, 0.0, 0.0] + + function dynamics!(du, u, p, t) + du[1] = -u[1] + tanh(p[1] * u[1] + p[2] * u[2]) + du[2] = -u[2] + tanh(p[3] * u[1] + p[4] * u[2]) + end + + function backsolve_grad(sol, lqr_params, checkpointing) + bwd_sol = solve(ODEAdjointProblem(sol, + BacksolveAdjoint(autojacvec = EnzymeVJP(), + checkpointing = checkpointing), + nothing, nothing, nothing, + (x, lqr_params, t) -> cost(x, lqr_params)), + Tsit5(), + dense = false, + save_everystep = false) + + bwd_sol.u[end][1:(end - x_dim)] + #fwd_sol, bwd_sol + end + + x0 = ones(x_dim) + fwd_sol = solve(ODEProblem(dynamics!, x0, (0, T), params), + Tsit5(), abstol = 1e-9, reltol = 1e-9, + u0 = x0, + p = params, + dense = false, + save_everystep = true) + + backsolve_results = backsolve_grad(fwd_sol, params, false) + backsolve_checkpointing_results = backsolve_grad(fwd_sol, params, true) + + @test backsolve_results != backsolve_checkpointing_results + + int_u0, int_p = adjoint_sensitivities(fwd_sol, Tsit5(), + g = (x, params, t) -> cost(x, params), + sensealg = InterpolatingAdjoint()) + + @test isapprox(backsolve_checkpointing_results[1:length(x0)], int_u0, rtol = 1e-10) + @test isapprox(backsolve_checkpointing_results[(1:length(params)) .+ length(x0)], + int_p', rtol = 1e-10) end using Test using LinearAlgebra, SciMLSensitivity, OrdinaryDiffEq, ForwardDiff, QuadGK @testset "Adjoint of differential algebric equations with mass matrix" begin - function G(p, prob, ts, cost) - tmp_prob_mm = remake(prob,u0=convert.(eltype(p),prob.u0),p=p) - sol = solve(tmp_prob_mm,Rodas5(autodiff=false),abstol=1e-14,reltol=1e-14,saveat=ts) - cost(sol) - end - alg = Rodas5(autodiff=false) - @testset "Fully ranked mass matrix" begin - @info "discrete cost" - A = [1 2 3; 4 5 6; 7 8 9] - function foo(du, u, p, t) - mul!(du, A, u) - du .= du .+ p - du[2] += sum(p) - return nothing + function G(p, prob, ts, cost) + tmp_prob_mm = remake(prob, u0 = convert.(eltype(p), prob.u0), p = p) + sol = solve(tmp_prob_mm, Rodas5(autodiff = false), abstol = 1e-14, reltol = 1e-14, + saveat = ts) + cost(sol) end - mm = -[1 2 4; 2 3 7; 1 3 41] - u0 = [1, 2.0, 3] - p = [1.0, 2.0, 3] - prob_mm = ODEProblem(ODEFunction(foo, mass_matrix=mm), u0, (0, 1.0), p) - sol_mm = solve(prob_mm, Rodas5(), reltol=1e-14, abstol=1e-14) - - ts = 0:0.01:1 - dg(out,u,p,t,i) = out .= 1 - _, res = adjoint_sensitivities(sol_mm,alg,t=ts,dg_discrete=dg,abstol=1e-14,reltol=1e-14,sensealg=QuadratureAdjoint()) - reference_sol = ForwardDiff.gradient(p->G(p, prob_mm, ts, sum),vec(p)) - @test res' ≈ reference_sol rtol=1e-11 - - _, res_interp = adjoint_sensitivities(sol_mm,alg,t=ts,dg_discrete=dg,abstol=1e-14,reltol=1e-14,sensealg=InterpolatingAdjoint()) - @test res_interp ≈ res rtol = 1e-11 - _, res_interp2 = adjoint_sensitivities(sol_mm,alg,t=ts,dg_discrete=dg,abstol=1e-14,reltol=1e-14,sensealg=InterpolatingAdjoint(checkpointing=true),checkpoints=sol_mm.t[1:10:end]) - @test res_interp2 ≈ res rtol = 1e-11 - - _, res_bs = adjoint_sensitivities(sol_mm,alg,t=ts,dg_discrete=dg,abstol=1e-14,reltol=1e-14,sensealg=BacksolveAdjoint(checkpointing=false)) - @test res_bs ≈ res rtol = 1e-11 - _, res_bs2 = adjoint_sensitivities(sol_mm,alg,t=ts,dg_discrete=dg,abstol=1e-14,reltol=1e-14,sensealg=BacksolveAdjoint(checkpointing=true),checkpoints=sol_mm.t) - @test res_bs2 ≈ res rtol = 1e-11 - - @info "continuous cost" - g_cont(u,p,t) = (sum(u).^2) ./ 2 - dg_cont(out,u,p,t) = out .= sum(u) - _,easy_res_cont = adjoint_sensitivities(sol_mm,alg,dg_continuous=dg_cont,g=g_cont, - abstol=1e-10,reltol=1e-10, - sensealg=QuadratureAdjoint()) - function G_cont(p) - tmp_prob_mm = remake(prob_mm,u0=eltype(p).(prob_mm.u0),p=p, - tspan=eltype(p).(prob_mm.tspan)) - sol = solve(tmp_prob_mm,Rodas5(autodiff=false),abstol=1e-14,reltol=1e-14) - res,err = quadgk((t)-> (sum(sol(t)).^2)./2,prob_mm.tspan...,atol=1e-14,rtol=1e-10) - res + alg = Rodas5(autodiff = false) + @testset "Fully ranked mass matrix" begin + @info "discrete cost" + A = [1 2 3; 4 5 6; 7 8 9] + function foo(du, u, p, t) + mul!(du, A, u) + du .= du .+ p + du[2] += sum(p) + return nothing + end + mm = -[1 2 4; 2 3 7; 1 3 41] + u0 = [1, 2.0, 3] + p = [1.0, 2.0, 3] + prob_mm = ODEProblem(ODEFunction(foo, mass_matrix = mm), u0, (0, 1.0), p) + sol_mm = solve(prob_mm, Rodas5(), reltol = 1e-14, abstol = 1e-14) + + ts = 0:0.01:1 + dg(out, u, p, t, i) = out .= 1 + _, res = adjoint_sensitivities(sol_mm, alg, t = ts, dg_discrete = dg, + abstol = 1e-14, reltol = 1e-14, + sensealg = QuadratureAdjoint()) + reference_sol = ForwardDiff.gradient(p -> G(p, prob_mm, ts, sum), vec(p)) + @test res'≈reference_sol rtol=1e-11 + + _, res_interp = adjoint_sensitivities(sol_mm, alg, t = ts, dg_discrete = dg, + abstol = 1e-14, reltol = 1e-14, + sensealg = InterpolatingAdjoint()) + @test res_interp≈res rtol=1e-11 + _, res_interp2 = adjoint_sensitivities(sol_mm, alg, t = ts, dg_discrete = dg, + abstol = 1e-14, reltol = 1e-14, + sensealg = InterpolatingAdjoint(checkpointing = true), + checkpoints = sol_mm.t[1:10:end]) + @test res_interp2≈res rtol=1e-11 + + _, res_bs = adjoint_sensitivities(sol_mm, alg, t = ts, dg_discrete = dg, + abstol = 1e-14, reltol = 1e-14, + sensealg = BacksolveAdjoint(checkpointing = false)) + @test res_bs≈res rtol=1e-11 + _, res_bs2 = adjoint_sensitivities(sol_mm, alg, t = ts, dg_discrete = dg, + abstol = 1e-14, reltol = 1e-14, + sensealg = BacksolveAdjoint(checkpointing = true), + checkpoints = sol_mm.t) + @test res_bs2≈res rtol=1e-11 + + @info "continuous cost" + g_cont(u, p, t) = (sum(u) .^ 2) ./ 2 + dg_cont(out, u, p, t) = out .= sum(u) + _, easy_res_cont = adjoint_sensitivities(sol_mm, alg, dg_continuous = dg_cont, + g = g_cont, + abstol = 1e-10, reltol = 1e-10, + sensealg = QuadratureAdjoint()) + function G_cont(p) + tmp_prob_mm = remake(prob_mm, u0 = eltype(p).(prob_mm.u0), p = p, + tspan = eltype(p).(prob_mm.tspan)) + sol = solve(tmp_prob_mm, Rodas5(autodiff = false), abstol = 1e-14, + reltol = 1e-14) + res, err = quadgk((t) -> (sum(sol(t)) .^ 2) ./ 2, prob_mm.tspan..., + atol = 1e-14, rtol = 1e-10) + res + end + reference_sol_cont = ForwardDiff.gradient(G_cont, p) + @test easy_res_cont'≈reference_sol_cont rtol=1e-3 end - reference_sol_cont = ForwardDiff.gradient(G_cont, p) - @test easy_res_cont' ≈ reference_sol_cont rtol=1e-3 - end - - @testset "Singular mass matrix" begin - function rober(du,u,p,t) - y₁,y₂,y₃ = u - k₁,k₂,k₃ = p - du[1] = -k₁*y₁+k₃*y₂*y₃ - du[2] = k₁*y₁-k₂*y₂^2-k₃*y₂*y₃ - du[3] = y₁ + y₂ + y₃ - 1 - nothing - end - function rober(u,p,t) - y₁,y₂,y₃ = u - k₁,k₂,k₃ = p - return [-k₁*y₁+k₃*y₂*y₃, - k₁*y₁-k₂*y₂^2-k₃*y₂*y₃, - y₁ + y₂ + y₃ - 1] - end - M = [1. 0 0 - 0 1. 0 - 0 0 0] - for iip in [true, false] - f = ODEFunction{iip}(rober,mass_matrix=M) - p = [0.04,3e7,1e4] - - prob_singular_mm = ODEProblem(f,[1.0,0.0,0.0],(0.0,100),p) - sol_singular_mm = solve(prob_singular_mm,Rodas5(autodiff=false),reltol=1e-12,abstol=1e-12) - ts = [50, sol_singular_mm.t[end]] - dg_singular(out,u,p,t,i) = (fill!(out, 0); out[end] = 1) - _, res = adjoint_sensitivities(sol_singular_mm,alg,t=ts,dg_discrete=dg_singular,abstol=1e-8,reltol=1e-8,sensealg=QuadratureAdjoint(),maxiters=Int(1e6)) - reference_sol = ForwardDiff.gradient(p->G(p, prob_singular_mm, ts, sol->sum(last, sol.u)), vec(p)) - @test res' ≈ reference_sol rtol = 1e-5 - - _, res_interp = adjoint_sensitivities(sol_singular_mm,alg,t=ts,dg_discrete=dg_singular,abstol=1e-8,reltol=1e-8,sensealg=InterpolatingAdjoint(),maxiters=Int(1e6)) - @test res_interp ≈ res rtol = 1e-5 - _, res_interp2 = adjoint_sensitivities(sol_singular_mm,alg,t=ts,dg_discrete=dg_singular,abstol=1e-8,reltol=1e-8,sensealg=InterpolatingAdjoint(checkpointing=true),checkpoints=sol_singular_mm.t[1:10:end]) - @test res_interp2 ≈ res rtol = 1e-5 - - # backsolve doesn't work - _, res_bs = adjoint_sensitivities(sol_singular_mm,alg,t=ts,dg_discrete=dg_singular,abstol=1e-8,reltol=1e-8,sensealg=BacksolveAdjoint(checkpointing=false)) - @test_broken res_bs ≈ res rtol = 1e-5 - _, res_bs2 = adjoint_sensitivities(sol_singular_mm,alg,t=ts,dg_discrete=dg_singular,abstol=1e-8,reltol=1e-8,sensealg=BacksolveAdjoint(checkpointing=true),checkpoints=sol_singular_mm.t) - @test_broken res_bs2 ≈ res rtol = 1e-5 + + @testset "Singular mass matrix" begin + function rober(du, u, p, t) + y₁, y₂, y₃ = u + k₁, k₂, k₃ = p + du[1] = -k₁ * y₁ + k₃ * y₂ * y₃ + du[2] = k₁ * y₁ - k₂ * y₂^2 - k₃ * y₂ * y₃ + du[3] = y₁ + y₂ + y₃ - 1 + nothing + end + function rober(u, p, t) + y₁, y₂, y₃ = u + k₁, k₂, k₃ = p + return [-k₁ * y₁ + k₃ * y₂ * y₃, + k₁ * y₁ - k₂ * y₂^2 - k₃ * y₂ * y₃, + y₁ + y₂ + y₃ - 1] + end + M = [1.0 0 0 + 0 1.0 0 + 0 0 0] + for iip in [true, false] + f = ODEFunction{iip}(rober, mass_matrix = M) + p = [0.04, 3e7, 1e4] + + prob_singular_mm = ODEProblem(f, [1.0, 0.0, 0.0], (0.0, 100), p) + sol_singular_mm = solve(prob_singular_mm, Rodas5(autodiff = false), + reltol = 1e-12, abstol = 1e-12) + ts = [50, sol_singular_mm.t[end]] + dg_singular(out, u, p, t, i) = (fill!(out, 0); out[end] = 1) + _, res = adjoint_sensitivities(sol_singular_mm, alg, t = ts, + dg_discrete = dg_singular, abstol = 1e-8, + reltol = 1e-8, sensealg = QuadratureAdjoint(), + maxiters = Int(1e6)) + reference_sol = ForwardDiff.gradient(p -> G(p, prob_singular_mm, ts, + sol -> sum(last, sol.u)), vec(p)) + @test res'≈reference_sol rtol=1e-5 + + _, res_interp = adjoint_sensitivities(sol_singular_mm, alg, t = ts, + dg_discrete = dg_singular, abstol = 1e-8, + reltol = 1e-8, + sensealg = InterpolatingAdjoint(), + maxiters = Int(1e6)) + @test res_interp≈res rtol=1e-5 + _, res_interp2 = adjoint_sensitivities(sol_singular_mm, alg, t = ts, + dg_discrete = dg_singular, abstol = 1e-8, + reltol = 1e-8, + sensealg = InterpolatingAdjoint(checkpointing = true), + checkpoints = sol_singular_mm.t[1:10:end]) + @test res_interp2≈res rtol=1e-5 + + # backsolve doesn't work + _, res_bs = adjoint_sensitivities(sol_singular_mm, alg, t = ts, + dg_discrete = dg_singular, abstol = 1e-8, + reltol = 1e-8, + sensealg = BacksolveAdjoint(checkpointing = false)) + @test_broken res_bs≈res rtol=1e-5 + _, res_bs2 = adjoint_sensitivities(sol_singular_mm, alg, t = ts, + dg_discrete = dg_singular, abstol = 1e-8, + reltol = 1e-8, + sensealg = BacksolveAdjoint(checkpointing = true), + checkpoints = sol_singular_mm.t) + @test_broken res_bs2≈res rtol=1e-5 + end end - end end diff --git a/test/adjoint_param.jl b/test/adjoint_param.jl index fd729af7f..7c67ff441 100644 --- a/test/adjoint_param.jl +++ b/test/adjoint_param.jl @@ -8,96 +8,96 @@ using Zygote function pendulum_eom(dx, x, p, t) dx[1] = p[1] * x[2] - dx[2] = -sin(x[1]) + (-p[1]*sin(x[1]) + p[2]*x[2]) # Second term is a simple controller that stabilizes π + dx[2] = -sin(x[1]) + (-p[1] * sin(x[1]) + p[2] * x[2]) # Second term is a simple controller that stabilizes π end x0 = [0.1, 0.0] tspan = (0.0, 10.0) p = [1.0, -24.05, -19.137] prob = ODEProblem(pendulum_eom, x0, tspan, p) -sol = solve(prob, Vern9(), abstol=1e-8, reltol=1e-8) +sol = solve(prob, Vern9(), abstol = 1e-8, reltol = 1e-8) -g(x, p, t) = 1.0*(x[1] - π)^2 + 1.0*x[2]^2 + 5.0*(-p[1]*sin(x[1]) + p[2]*x[2])^2 +g(x, p, t) = 1.0 * (x[1] - π)^2 + 1.0 * x[2]^2 + 5.0 * (-p[1] * sin(x[1]) + p[2] * x[2])^2 dgdu(out, y, p, t) = ForwardDiff.gradient!(out, y -> g(y, p, t), y) dgdp(out, y, p, t) = ForwardDiff.gradient!(out, p -> g(y, p, t), p) -res_interp = adjoint_sensitivities(sol,Vern9(),g,nothing,(dgdu, dgdp),abstol=1e-8, - reltol=1e-8,iabstol=1e-8,ireltol=1e-8, sensealg=InterpolatingAdjoint()) -res_quad = adjoint_sensitivities(sol,Vern9(),g,nothing,(dgdu, dgdp),abstol=1e-8, - reltol=1e-8,iabstol=1e-8,ireltol=1e-8, sensealg=QuadratureAdjoint()) +res_interp = adjoint_sensitivities(sol, Vern9(), g, nothing, (dgdu, dgdp), abstol = 1e-8, + reltol = 1e-8, iabstol = 1e-8, ireltol = 1e-8, + sensealg = InterpolatingAdjoint()) +res_quad = adjoint_sensitivities(sol, Vern9(), g, nothing, (dgdu, dgdp), abstol = 1e-8, + reltol = 1e-8, iabstol = 1e-8, ireltol = 1e-8, + sensealg = QuadratureAdjoint()) #res_back = adjoint_sensitivities(sol,Vern9(),g,nothing,(dgdu, dgdp),abstol=1e-8, # reltol=1e-8,iabstol=1e-8,ireltol=1e-8, sensealg=BacksolveAdjoint(checkpointing=true), sol=sol.t) # it's blowing up function G(p) - tmp_prob = remake(prob,p=p,u0=convert.(eltype(p), prob.u0)) - sol = solve(tmp_prob,Vern9(),abstol=1e-8,reltol=1e-8) - res,err = quadgk((t)-> g(sol(t), p, t), 0.0,10.0,atol=1e-8,rtol=1e-8) + tmp_prob = remake(prob, p = p, u0 = convert.(eltype(p), prob.u0)) + sol = solve(tmp_prob, Vern9(), abstol = 1e-8, reltol = 1e-8) + res, err = quadgk((t) -> g(sol(t), p, t), 0.0, 10.0, atol = 1e-8, rtol = 1e-8) res end -res2 = ForwardDiff.gradient(G,p) +res2 = ForwardDiff.gradient(G, p) -@test res_interp[2]' ≈ res2 atol=1e-5 -@test res_quad[2]' ≈ res2 atol=1e-5 +@test res_interp[2]'≈res2 atol=1e-5 +@test res_quad[2]'≈res2 atol=1e-5 -p = [2.0,3.0] +p = [2.0, 3.0] u0 = [2.0] -function f(du,u,p,t) - du[1] = -u[1]*p[1]-p[2] +function f(du, u, p, t) + du[1] = -u[1] * p[1] - p[2] end -prob = ODEProblem(f,u0,(0.0,1.0),p) -sol = solve(prob,Tsit5(),abstol=1e-10,reltol=1e-10); +prob = ODEProblem(f, u0, (0.0, 1.0), p) +sol = solve(prob, Tsit5(), abstol = 1e-10, reltol = 1e-10); -g(u,p,t) = -u[1]*p[1]-p[2] +g(u, p, t) = -u[1] * p[1] - p[2] dgdu(out, y, p, t) = ForwardDiff.gradient!(out, y -> g(y, p, t), y) dgdp(out, y, p, t) = ForwardDiff.gradient!(out, p -> g(y, p, t), p) -du0,dp = adjoint_sensitivities(sol,Vern9(),g,nothing,(dgdu,dgdp);abstol=1e-10,reltol=1e-10) +du0, dp = adjoint_sensitivities(sol, Vern9(), g, nothing, (dgdu, dgdp); abstol = 1e-10, + reltol = 1e-10) function G(p) - tmp_prob = remake(prob,p=p,u0=convert.(eltype(p), prob.u0)) - sol = solve(tmp_prob,Vern9(),abstol=1e-8,reltol=1e-8) - res,err = quadgk((t)-> g(sol(t), p, t), 0.0,10.0,atol=1e-8,rtol=1e-8) + tmp_prob = remake(prob, p = p, u0 = convert.(eltype(p), prob.u0)) + sol = solve(tmp_prob, Vern9(), abstol = 1e-8, reltol = 1e-8) + res, err = quadgk((t) -> g(sol(t), p, t), 0.0, 10.0, atol = 1e-8, rtol = 1e-8) res end -res2 = ForwardDiff.gradient(G,p) +res2 = ForwardDiff.gradient(G, p) -@test dp' ≈ res2 atol=1e-5 +@test dp'≈res2 atol=1e-5 function model(p) N_oscillators = 30 u0 = repeat([0.0; 1.0], 1, N_oscillators) # size(u0) = (2, 30) function du!(du, u, p, t) - W, b = p # Parameters - dy = @view du[1,:] # 30 elements - dy′ = @view du[2,:] - y = @view u[1,:] - y′= @view u[2,:] + W, b = p # Parameters + dy = @view du[1, :] # 30 elements + dy′ = @view du[2, :] + y = @view u[1, :] + y′ = @view u[2, :] @. dy′ = -y * W - @. dy = y′ * b + @. dy = y′ * b end - output = solve( - ODEProblem( - du!, - u0, - (0.0, 10.0), - p, - jac = true, - abstol = 1e-12, - reltol = 1e-12), - Tsit5(), - jac = true, - saveat = collect(0:0.1:7), - sensealg = QuadratureAdjoint(), - ) + output = solve(ODEProblem(du!, + u0, + (0.0, 10.0), + p, + jac = true, + abstol = 1e-12, + reltol = 1e-12), + Tsit5(), + jac = true, + saveat = collect(0:0.1:7), + sensealg = QuadratureAdjoint()) return Array(output[1, :, :]) # only return y, not y′ end -p=[1.5, 0.1] +p = [1.5, 0.1] y = model(p) loss(p) = sum(model(p)) -dp1 = Zygote.gradient(loss,p)[1] -dp2 = ForwardDiff.gradient(loss,p) +dp1 = Zygote.gradient(loss, p)[1] +dp2 = ForwardDiff.gradient(loss, p) @test dp1 ≈ dp2 diff --git a/test/adjoint_shapes.jl b/test/adjoint_shapes.jl index 96f9d04ae..2df20711d 100644 --- a/test/adjoint_shapes.jl +++ b/test/adjoint_shapes.jl @@ -1,15 +1,15 @@ using OrdinaryDiffEq, SciMLSensitivity, Zygote -tspan = (0., 1.) +tspan = (0.0, 1.0) X = randn(3, 4) p = randn(3, 4) -f(u,p,t) = u .* p -f(du,u,p,t) = (du .= u .* p) +f(u, p, t) = u .* p +f(du, u, p, t) = (du .= u .* p) prob_ube = ODEProblem{false}(f, X, tspan, p) -Zygote.gradient(p->sum(solve(prob_ube, Midpoint(), u0 = X, p = p)),p) +Zygote.gradient(p -> sum(solve(prob_ube, Midpoint(), u0 = X, p = p)), p) prob_ube = ODEProblem{true}(f, X, tspan, p) -Zygote.gradient(p->sum(solve(prob_ube, Midpoint(), u0 = X, p = p)),p) +Zygote.gradient(p -> sum(solve(prob_ube, Midpoint(), u0 = X, p = p)), p) function aug_dynamics!(dz, z, K, t) x = @view z[2:end] @@ -20,20 +20,15 @@ end policy_params = ones(2, 2) z0 = zeros(3) -fwd_sol = solve( - ODEProblem(aug_dynamics!, z0, (0.0, 1.0), policy_params), - Tsit5(), - u0 = z0, - p = policy_params) +fwd_sol = solve(ODEProblem(aug_dynamics!, z0, (0.0, 1.0), policy_params), + Tsit5(), + u0 = z0, + p = policy_params) -solve( - ODEAdjointProblem( - fwd_sol, - InterpolatingAdjoint(), - (out, x, p, t, i) -> (out .= 1), - [1.0], - ),Tsit5() -) +solve(ODEAdjointProblem(fwd_sol, + InterpolatingAdjoint(), + (out, x, p, t, i) -> (out .= 1), + [1.0]), Tsit5()) A = ones(2, 2) B = ones(2, 2) @@ -49,17 +44,11 @@ end policy_params = ones(2, 2) z0 = zeros(3) -fwd_sol = solve( - ODEProblem(aug_dynamics!, z0, (0.0, 1.0), policy_params), - u0 = z0, - p = policy_params, -) - -solve( - ODEAdjointProblem( - fwd_sol, - InterpolatingAdjoint(), - (out, x, p, t, i) -> (out .= 1), - [1.0], - ), -) +fwd_sol = solve(ODEProblem(aug_dynamics!, z0, (0.0, 1.0), policy_params), + u0 = z0, + p = policy_params) + +solve(ODEAdjointProblem(fwd_sol, + InterpolatingAdjoint(), + (out, x, p, t, i) -> (out .= 1), + [1.0])) diff --git a/test/alternative_ad_frontend.jl b/test/alternative_ad_frontend.jl index ea77c5c83..5eb147888 100644 --- a/test/alternative_ad_frontend.jl +++ b/test/alternative_ad_frontend.jl @@ -1,95 +1,117 @@ -using OrdinaryDiffEq, SciMLSensitivity, ForwardDiff, Zygote, ReverseDiff, Tracker -using Test - -prob = ODEProblem((u,p,t)->u .* p,[2.0],(0.0,1.0),[3.0]) - -struct senseloss; sense end -(f::senseloss)(u0p) = sum(solve(prob,Tsit5(),u0=u0p[1:1],p=u0p[2:2],abstol=1e-12, - reltol=1e-12,saveat=0.1,sensealg=f.sense)) -loss(u0p) = sum(solve(prob,Tsit5(),u0=u0p[1:1],p=u0p[2:2],abstol=1e-12,reltol=1e-12,saveat=0.1)) -u0p = [2.0,3.0] - -dup = Zygote.gradient(senseloss(InterpolatingAdjoint()),u0p)[1] - -@test ReverseDiff.gradient(senseloss(InterpolatingAdjoint()),u0p) ≈ dup -@test_broken ReverseDiff.gradient(senseloss(ReverseDiffAdjoint()),u0p) ≈ dup -@test ReverseDiff.gradient(senseloss(TrackerAdjoint()),u0p) ≈ dup -@test ReverseDiff.gradient(senseloss(ForwardDiffSensitivity()),u0p) ≈ dup -@test_throws SciMLSensitivity.ForwardSensitivityOutOfPlaceError ReverseDiff.gradient(senseloss(ForwardSensitivity()),u0p) ≈ dup - -@test Tracker.gradient(senseloss(InterpolatingAdjoint()),u0p)[1] ≈ dup -@test_broken Tracker.gradient(senseloss(ReverseDiffAdjoint()),u0p)[1] ≈ dup -@test Tracker.gradient(senseloss(TrackerAdjoint()),u0p)[1] ≈ dup -@test Tracker.gradient(senseloss(ForwardDiffSensitivity()),u0p)[1] ≈ dup -@test_throws SciMLSensitivity.ForwardSensitivityOutOfPlaceError Tracker.gradient(senseloss(ForwardSensitivity()),u0p)[1] ≈ dup - -@test ForwardDiff.gradient(senseloss(InterpolatingAdjoint()),u0p) ≈ dup - -struct senseloss2; sense end -prob2 = ODEProblem((du,u,p,t)->du .= u .* p,[2.0],(0.0,1.0),[3.0]) - -(f::senseloss2)(u0p) = sum(solve(prob2,Tsit5(),u0=u0p[1:1],p=u0p[2:2],abstol=1e-12, - reltol=1e-12,saveat=0.1,sensealg=f.sense)) - -u0p = [2.0,3.0] - -dup = Zygote.gradient(senseloss2(InterpolatingAdjoint()),u0p)[1] - -@test ReverseDiff.gradient(senseloss2(InterpolatingAdjoint()),u0p) ≈ dup -@test ReverseDiff.gradient(senseloss2(ReverseDiffAdjoint()),u0p) ≈ dup -@test ReverseDiff.gradient(senseloss2(TrackerAdjoint()),u0p) ≈ dup -@test ReverseDiff.gradient(senseloss2(ForwardDiffSensitivity()),u0p) ≈ dup -@test_broken ReverseDiff.gradient(senseloss2(ForwardSensitivity()),u0p) ≈ dup - -@test Tracker.gradient(senseloss2(InterpolatingAdjoint()),u0p)[1] ≈ dup -@test Tracker.gradient(senseloss2(ReverseDiffAdjoint()),u0p)[1] ≈ dup -@test Tracker.gradient(senseloss2(TrackerAdjoint()),u0p)[1] ≈ dup -@test Tracker.gradient(senseloss2(ForwardDiffSensitivity()),u0p)[1] ≈ dup -@test_broken Tracker.gradient(senseloss2(ForwardSensitivity()),u0p)[1] ≈ dup - -@test ForwardDiff.gradient(senseloss2(InterpolatingAdjoint()),u0p) ≈ dup - -struct senseloss3; sense end -(f::senseloss3)(u0p) = sum(solve(prob2,Tsit5(),p=u0p,abstol=1e-12, - reltol=1e-12,saveat=0.1,sensealg=f.sense)) - -u0p = [3.0] - -dup = Zygote.gradient(senseloss3(InterpolatingAdjoint()),u0p)[1] - -@test ReverseDiff.gradient(senseloss3(InterpolatingAdjoint()),u0p) ≈ dup -@test ReverseDiff.gradient(senseloss3(ReverseDiffAdjoint()),u0p) ≈ dup -@test ReverseDiff.gradient(senseloss3(TrackerAdjoint()),u0p) ≈ dup -@test ReverseDiff.gradient(senseloss3(ForwardDiffSensitivity()),u0p) ≈ dup -@test ReverseDiff.gradient(senseloss3(ForwardSensitivity()),u0p) ≈ dup - -@test Tracker.gradient(senseloss3(InterpolatingAdjoint()),u0p)[1] ≈ dup -@test Tracker.gradient(senseloss3(ReverseDiffAdjoint()),u0p)[1] ≈ dup -@test Tracker.gradient(senseloss3(TrackerAdjoint()),u0p)[1] ≈ dup -@test Tracker.gradient(senseloss3(ForwardDiffSensitivity()),u0p)[1] ≈ dup -@test Tracker.gradient(senseloss3(ForwardSensitivity()),u0p)[1] ≈ dup - -@test ForwardDiff.gradient(senseloss3(InterpolatingAdjoint()),u0p) ≈ dup - - -struct senseloss4; sense end -(f::senseloss4)(u0p) = sum(solve(prob,Tsit5(),p=u0p,abstol=1e-12, - reltol=1e-12,saveat=0.1,sensealg=f.sense)) - -u0p = [3.0] - -dup = Zygote.gradient(senseloss4(InterpolatingAdjoint()),u0p)[1] - -@test ReverseDiff.gradient(senseloss4(InterpolatingAdjoint()),u0p) ≈ dup -@test_broken ReverseDiff.gradient(senseloss4(ReverseDiffAdjoint()),u0p) ≈ dup -@test ReverseDiff.gradient(senseloss4(TrackerAdjoint()),u0p) ≈ dup -@test ReverseDiff.gradient(senseloss4(ForwardDiffSensitivity()),u0p) ≈ dup -@test_throws SciMLSensitivity.ForwardSensitivityOutOfPlaceError ReverseDiff.gradient(senseloss4(ForwardSensitivity()),u0p) ≈ dup - -@test Tracker.gradient(senseloss4(InterpolatingAdjoint()),u0p)[1] ≈ dup -@test_broken Tracker.gradient(senseloss4(ReverseDiffAdjoint()),u0p)[1] ≈ dup -@test Tracker.gradient(senseloss4(TrackerAdjoint()),u0p)[1] ≈ dup -@test Tracker.gradient(senseloss4(ForwardDiffSensitivity()),u0p)[1] ≈ dup -@test_throws SciMLSensitivity.ForwardSensitivityOutOfPlaceError Tracker.gradient(senseloss4(ForwardSensitivity()),u0p)[1] ≈ dup - -@test ForwardDiff.gradient(senseloss4(InterpolatingAdjoint()),u0p) ≈ dup \ No newline at end of file +using OrdinaryDiffEq, SciMLSensitivity, ForwardDiff, Zygote, ReverseDiff, Tracker +using Test + +prob = ODEProblem((u, p, t) -> u .* p, [2.0], (0.0, 1.0), [3.0]) + +struct senseloss + sense::Any +end +function (f::senseloss)(u0p) + sum(solve(prob, Tsit5(), u0 = u0p[1:1], p = u0p[2:2], abstol = 1e-12, + reltol = 1e-12, saveat = 0.1, sensealg = f.sense)) +end +function loss(u0p) + sum(solve(prob, Tsit5(), u0 = u0p[1:1], p = u0p[2:2], abstol = 1e-12, reltol = 1e-12, + saveat = 0.1)) +end +u0p = [2.0, 3.0] + +dup = Zygote.gradient(senseloss(InterpolatingAdjoint()), u0p)[1] + +@test ReverseDiff.gradient(senseloss(InterpolatingAdjoint()), u0p) ≈ dup +@test_broken ReverseDiff.gradient(senseloss(ReverseDiffAdjoint()), u0p) ≈ dup +@test ReverseDiff.gradient(senseloss(TrackerAdjoint()), u0p) ≈ dup +@test ReverseDiff.gradient(senseloss(ForwardDiffSensitivity()), u0p) ≈ dup +@test_throws SciMLSensitivity.ForwardSensitivityOutOfPlaceError ReverseDiff.gradient(senseloss(ForwardSensitivity()), + u0p)≈dup + +@test Tracker.gradient(senseloss(InterpolatingAdjoint()), u0p)[1] ≈ dup +@test_broken Tracker.gradient(senseloss(ReverseDiffAdjoint()), u0p)[1] ≈ dup +@test Tracker.gradient(senseloss(TrackerAdjoint()), u0p)[1] ≈ dup +@test Tracker.gradient(senseloss(ForwardDiffSensitivity()), u0p)[1] ≈ dup +@test_throws SciMLSensitivity.ForwardSensitivityOutOfPlaceError Tracker.gradient(senseloss(ForwardSensitivity()), + u0p)[1]≈dup + +@test ForwardDiff.gradient(senseloss(InterpolatingAdjoint()), u0p) ≈ dup + +struct senseloss2 + sense::Any +end +prob2 = ODEProblem((du, u, p, t) -> du .= u .* p, [2.0], (0.0, 1.0), [3.0]) + +function (f::senseloss2)(u0p) + sum(solve(prob2, Tsit5(), u0 = u0p[1:1], p = u0p[2:2], abstol = 1e-12, + reltol = 1e-12, saveat = 0.1, sensealg = f.sense)) +end + +u0p = [2.0, 3.0] + +dup = Zygote.gradient(senseloss2(InterpolatingAdjoint()), u0p)[1] + +@test ReverseDiff.gradient(senseloss2(InterpolatingAdjoint()), u0p) ≈ dup +@test ReverseDiff.gradient(senseloss2(ReverseDiffAdjoint()), u0p) ≈ dup +@test ReverseDiff.gradient(senseloss2(TrackerAdjoint()), u0p) ≈ dup +@test ReverseDiff.gradient(senseloss2(ForwardDiffSensitivity()), u0p) ≈ dup +@test_broken ReverseDiff.gradient(senseloss2(ForwardSensitivity()), u0p) ≈ dup + +@test Tracker.gradient(senseloss2(InterpolatingAdjoint()), u0p)[1] ≈ dup +@test Tracker.gradient(senseloss2(ReverseDiffAdjoint()), u0p)[1] ≈ dup +@test Tracker.gradient(senseloss2(TrackerAdjoint()), u0p)[1] ≈ dup +@test Tracker.gradient(senseloss2(ForwardDiffSensitivity()), u0p)[1] ≈ dup +@test_broken Tracker.gradient(senseloss2(ForwardSensitivity()), u0p)[1] ≈ dup + +@test ForwardDiff.gradient(senseloss2(InterpolatingAdjoint()), u0p) ≈ dup + +struct senseloss3 + sense::Any +end +function (f::senseloss3)(u0p) + sum(solve(prob2, Tsit5(), p = u0p, abstol = 1e-12, + reltol = 1e-12, saveat = 0.1, sensealg = f.sense)) +end + +u0p = [3.0] + +dup = Zygote.gradient(senseloss3(InterpolatingAdjoint()), u0p)[1] + +@test ReverseDiff.gradient(senseloss3(InterpolatingAdjoint()), u0p) ≈ dup +@test ReverseDiff.gradient(senseloss3(ReverseDiffAdjoint()), u0p) ≈ dup +@test ReverseDiff.gradient(senseloss3(TrackerAdjoint()), u0p) ≈ dup +@test ReverseDiff.gradient(senseloss3(ForwardDiffSensitivity()), u0p) ≈ dup +@test ReverseDiff.gradient(senseloss3(ForwardSensitivity()), u0p) ≈ dup + +@test Tracker.gradient(senseloss3(InterpolatingAdjoint()), u0p)[1] ≈ dup +@test Tracker.gradient(senseloss3(ReverseDiffAdjoint()), u0p)[1] ≈ dup +@test Tracker.gradient(senseloss3(TrackerAdjoint()), u0p)[1] ≈ dup +@test Tracker.gradient(senseloss3(ForwardDiffSensitivity()), u0p)[1] ≈ dup +@test Tracker.gradient(senseloss3(ForwardSensitivity()), u0p)[1] ≈ dup + +@test ForwardDiff.gradient(senseloss3(InterpolatingAdjoint()), u0p) ≈ dup + +struct senseloss4 + sense::Any +end +function (f::senseloss4)(u0p) + sum(solve(prob, Tsit5(), p = u0p, abstol = 1e-12, + reltol = 1e-12, saveat = 0.1, sensealg = f.sense)) +end + +u0p = [3.0] + +dup = Zygote.gradient(senseloss4(InterpolatingAdjoint()), u0p)[1] + +@test ReverseDiff.gradient(senseloss4(InterpolatingAdjoint()), u0p) ≈ dup +@test_broken ReverseDiff.gradient(senseloss4(ReverseDiffAdjoint()), u0p) ≈ dup +@test ReverseDiff.gradient(senseloss4(TrackerAdjoint()), u0p) ≈ dup +@test ReverseDiff.gradient(senseloss4(ForwardDiffSensitivity()), u0p) ≈ dup +@test_throws SciMLSensitivity.ForwardSensitivityOutOfPlaceError ReverseDiff.gradient(senseloss4(ForwardSensitivity()), + u0p)≈dup + +@test Tracker.gradient(senseloss4(InterpolatingAdjoint()), u0p)[1] ≈ dup +@test_broken Tracker.gradient(senseloss4(ReverseDiffAdjoint()), u0p)[1] ≈ dup +@test Tracker.gradient(senseloss4(TrackerAdjoint()), u0p)[1] ≈ dup +@test Tracker.gradient(senseloss4(ForwardDiffSensitivity()), u0p)[1] ≈ dup +@test_throws SciMLSensitivity.ForwardSensitivityOutOfPlaceError Tracker.gradient(senseloss4(ForwardSensitivity()), + u0p)[1]≈dup + +@test ForwardDiff.gradient(senseloss4(InterpolatingAdjoint()), u0p) ≈ dup diff --git a/test/array_partitions.jl b/test/array_partitions.jl index 715ab59be..3a044e05c 100644 --- a/test/array_partitions.jl +++ b/test/array_partitions.jl @@ -1,38 +1,28 @@ import OrdinaryDiffEq import DiffEqBase: DynamicalODEProblem import SciMLSensitivity: - solve, - ODEProblem, - ODEAdjointProblem, - InterpolatingAdjoint, - ZygoteVJP, - ReverseDiffVJP + solve, + ODEProblem, + ODEAdjointProblem, + InterpolatingAdjoint, + ZygoteVJP, + ReverseDiffVJP import RecursiveArrayTools: ArrayPartition -sol = solve( - DynamicalODEProblem( - (v, x, p, t) -> [0.0, 0.0], +sol = solve(DynamicalODEProblem((v, x, p, t) -> [0.0, 0.0], - # ERROR: LoadError: type Nothing has no field x - # (v, x, p, t) -> [0.0, 0.0], + # ERROR: LoadError: type Nothing has no field x + # (v, x, p, t) -> [0.0, 0.0], - # ERROR: LoadError: MethodError: no method matching ndims(::Type{Nothing}) - (v, x, p, t) -> v, + # ERROR: LoadError: MethodError: no method matching ndims(::Type{Nothing}) + (v, x, p, t) -> v, [0.0, 0.0], + [0.0, 0.0], + (0.0, 1.0)), OrdinaryDiffEq.Tsit5()) - [0.0, 0.0], - [0.0, 0.0], - (0.0, 1.0), - ),OrdinaryDiffEq.Tsit5() -) - -solve( - ODEAdjointProblem( - sol, - InterpolatingAdjoint(autojacvec=ZygoteVJP(allow_nothing=true)), - [sol.t[end]], - (out, x, p, t, i) -> (out .= 0) - ),OrdinaryDiffEq.Tsit5() -) +solve(ODEAdjointProblem(sol, + InterpolatingAdjoint(autojacvec = ZygoteVJP(allow_nothing = true)), + [sol.t[end]], + (out, x, p, t, i) -> (out .= 0)), OrdinaryDiffEq.Tsit5()) dyn_v(v_ap, x_ap, p, t) = ArrayPartition(zeros(), [0.0]) # Originally, I imagined that this may be a bug in Zygote, and it still may be, but I tried doing a pullback on this @@ -50,27 +40,20 @@ end v0 = [-1.0] x0 = [0.75] -sol = solve( - DynamicalODEProblem( - dyn_v, - dyn_x, - ArrayPartition(zeros(), v0), - ArrayPartition(zeros(), x0), - (0.0, 1.0), - zeros() - ),OrdinaryDiffEq.Tsit5(), - # Without setting parameters, we end up with https://github.com/SciML/DifferentialEquations.jl/issues/679 again. - p = zeros() -) +sol = solve(DynamicalODEProblem(dyn_v, + dyn_x, + ArrayPartition(zeros(), v0), + ArrayPartition(zeros(), x0), + (0.0, 1.0), + zeros()), OrdinaryDiffEq.Tsit5(), + # Without setting parameters, we end up with https://github.com/SciML/DifferentialEquations.jl/issues/679 again. + p = zeros()) g = ArrayPartition(ArrayPartition(zeros(), zero(v0)), ArrayPartition(zeros(), zero(x0))) -bwd_sol = solve( - ODEAdjointProblem( - sol, - InterpolatingAdjoint(autojacvec=ZygoteVJP(allow_nothing=true)), - # Also fails, but due to a different bug: - # InterpolatingAdjoint(autojacvec=ReverseDiffVJP()), - [sol.t[end]], - (out, x, p, t, i) -> (out[:] = g) - ),OrdinaryDiffEq.Tsit5() -) +bwd_sol = solve(ODEAdjointProblem(sol, + InterpolatingAdjoint(autojacvec = ZygoteVJP(allow_nothing = true)), + # Also fails, but due to a different bug: + # InterpolatingAdjoint(autojacvec=ReverseDiffVJP()), + [sol.t[end]], + (out, x, p, t, i) -> (out[:] = g)), + OrdinaryDiffEq.Tsit5()) diff --git a/test/autodiff_events.jl b/test/autodiff_events.jl index 4d52ca7c8..6b50fd6a2 100644 --- a/test/autodiff_events.jl +++ b/test/autodiff_events.jl @@ -2,61 +2,67 @@ using SciMLSensitivity using OrdinaryDiffEq, Calculus, Test using Zygote -function f(du,u,p,t) - du[1] = u[2] - du[2] = -p[1] +function f(du, u, p, t) + du[1] = u[2] + du[2] = -p[1] end -function condition(u,t,integrator) # Event when event_f(u,t) == 0 - u[1] +function condition(u, t, integrator) # Event when event_f(u,t) == 0 + u[1] end function affect!(integrator) - @show integrator.t - println("bounced.") - integrator.u[2] = -integrator.p[2]*integrator.u[2] + @show integrator.t + println("bounced.") + integrator.u[2] = -integrator.p[2] * integrator.u[2] end cb = ContinuousCallback(condition, affect!) p = [9.8, 0.8] -prob = ODEProblem(f,eltype(p).([1.0,0.0]),eltype(p).((0.0,1.0)),copy(p)) +prob = ODEProblem(f, eltype(p).([1.0, 0.0]), eltype(p).((0.0, 1.0)), copy(p)) function test_f(p) - _prob = remake(prob, p=p) - solve(_prob,Tsit5(),abstol=1e-14,reltol=1e-14,callback=cb,save_everystep=false)[end] + _prob = remake(prob, p = p) + solve(_prob, Tsit5(), abstol = 1e-14, reltol = 1e-14, callback = cb, + save_everystep = false)[end] end -findiff = Calculus.finite_difference_jacobian(test_f,p) +findiff = Calculus.finite_difference_jacobian(test_f, p) findiff using ForwardDiff -ad = ForwardDiff.jacobian(test_f,p) +ad = ForwardDiff.jacobian(test_f, p) ad @test ad ≈ findiff -function test_f2(p, sensealg=ForwardDiffSensitivity(), controller=nothing, alg=Tsit5()) - _prob = remake(prob, p=p) - u = solve(_prob,alg,sensealg=sensealg,controller=controller, - abstol=1e-14,reltol=1e-14,callback=cb,save_everystep=false) - u[end][end] +function test_f2(p, sensealg = ForwardDiffSensitivity(), controller = nothing, + alg = Tsit5()) + _prob = remake(prob, p = p) + u = solve(_prob, alg, sensealg = sensealg, controller = controller, + abstol = 1e-14, reltol = 1e-14, callback = cb, save_everystep = false) + u[end][end] end @test test_f2(p) == test_f(p)[end] -g1 = Zygote.gradient(θ->test_f2(θ,ForwardDiffSensitivity()), p) -g2 = Zygote.gradient(θ->test_f2(θ,ReverseDiffAdjoint()), p) -g3 = Zygote.gradient(θ->test_f2(θ,ReverseDiffAdjoint(), IController()), p) -g4 = Zygote.gradient(θ->test_f2(θ,ReverseDiffAdjoint(), PIController(7//50, 2//25)), p) -@test_broken g5 = Zygote.gradient(θ->test_f2(θ,ReverseDiffAdjoint(), PIDController(1/18. , 1/9., 1/18.)), p) -g6 = Zygote.gradient(θ->test_f2(θ,ForwardDiffSensitivity(), - OrdinaryDiffEq.PredictiveController(), TRBDF2()), p) -@test_broken g7 = Zygote.gradient(θ->test_f2(θ,ReverseDiffAdjoint(), - OrdinaryDiffEq.PredictiveController(), TRBDF2()), p) - -@test g1[1] ≈ findiff[2,1:2] -@test g2[1] ≈ findiff[2,1:2] -@test g3[1] ≈ findiff[2,1:2] -@test g4[1] ≈ findiff[2,1:2] -@test_broken g5[1] ≈ findiff[2,1:2] -@test g6[1] ≈ findiff[2,1:2] -@test_broken g7[1] ≈ findiff[2,1:2] +g1 = Zygote.gradient(θ -> test_f2(θ, ForwardDiffSensitivity()), p) +g2 = Zygote.gradient(θ -> test_f2(θ, ReverseDiffAdjoint()), p) +g3 = Zygote.gradient(θ -> test_f2(θ, ReverseDiffAdjoint(), IController()), p) +g4 = Zygote.gradient(θ -> test_f2(θ, ReverseDiffAdjoint(), PIController(7 // 50, 2 // 25)), + p) +@test_broken g5 = Zygote.gradient(θ -> test_f2(θ, ReverseDiffAdjoint(), + PIDController(1 / 18.0, 1 / 9.0, 1 / 18.0)), + p) +g6 = Zygote.gradient(θ -> test_f2(θ, ForwardDiffSensitivity(), + OrdinaryDiffEq.PredictiveController(), TRBDF2()), p) +@test_broken g7 = Zygote.gradient(θ -> test_f2(θ, ReverseDiffAdjoint(), + OrdinaryDiffEq.PredictiveController(), + TRBDF2()), p) + +@test g1[1] ≈ findiff[2, 1:2] +@test g2[1] ≈ findiff[2, 1:2] +@test g3[1] ≈ findiff[2, 1:2] +@test g4[1] ≈ findiff[2, 1:2] +@test_broken g5[1] ≈ findiff[2, 1:2] +@test g6[1] ≈ findiff[2, 1:2] +@test_broken g7[1] ≈ findiff[2, 1:2] diff --git a/test/branching_derivatives.jl b/test/branching_derivatives.jl index c989ec2ed..da1b356ab 100644 --- a/test/branching_derivatives.jl +++ b/test/branching_derivatives.jl @@ -1,33 +1,48 @@ -using SciMLSensitivity, OrdinaryDiffEq, Zygote, Test - -function get_param(breakpoints, values, t) - for (i, tᵢ) in enumerate(breakpoints) - if t <= tᵢ - return values[i] - end - end - return values[end] -end - -function fiip(du, u, p, t) - a = get_param([1., 2., 3.], p[1:4], t) - du[1] = dx = a * u[1] - u[1] * u[2] - du[2] = dy = -a * u[2] + u[1] * u[2] -end - -p = [1., 1., 1., 1.]; u0 = [1.0;1.0] -prob = ODEProblem(fiip, u0, (0.0, 4.0), p); - -dp1 = Zygote.gradient(p->sum(solve(prob, Tsit5(), u0=u0, p=p, sensealg = ForwardDiffSensitivity(), saveat = 0.1, abstol=1e-12, reltol=1e-12)), p) -dp2 = Zygote.gradient(p->sum(solve(prob, Tsit5(), u0=u0, p=p, sensealg = ForwardDiffSensitivity(convert_tspan=true), saveat = 0.1, abstol=1e-12, reltol=1e-12)), p) -dp3 = Zygote.gradient(p->sum(solve(prob, Tsit5(), u0=u0, p=p, sensealg = ForwardSensitivity(), saveat = 0.1, abstol=1e-12, reltol=1e-12)), p) -dp4 = Zygote.gradient(p->sum(solve(prob, Tsit5(), u0=u0, p=p, saveat = 0.1, abstol=1e-12, reltol=1e-12)), p) -dp5 = Zygote.gradient(p->sum(solve(prob, Tsit5(), u0=u0, p=p, saveat = 0.1, abstol=1e-12, reltol=1e-12, sensealg = InterpolatingAdjoint(autojacvec=ReverseDiffVJP()))), p) -dp6 = Zygote.gradient(p->sum(solve(prob, Tsit5(), u0=u0, p=p, saveat = 0.1, abstol=1e-12, reltol=1e-12, sensealg = InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true)))), p) - -@test dp1[1] ≈ dp2[1] -@test dp1[1] ≈ dp3[1] -@test dp1[1] ≈ dp4[1] -@test dp1[1] ≈ dp5[1] -@test sum(dp5[1]) ≈ sum(dp6[1]) -@test all(dp6[1][1:3] .== 0) +using SciMLSensitivity, OrdinaryDiffEq, Zygote, Test + +function get_param(breakpoints, values, t) + for (i, tᵢ) in enumerate(breakpoints) + if t <= tᵢ + return values[i] + end + end + return values[end] +end + +function fiip(du, u, p, t) + a = get_param([1.0, 2.0, 3.0], p[1:4], t) + du[1] = dx = a * u[1] - u[1] * u[2] + du[2] = dy = -a * u[2] + u[1] * u[2] +end + +p = [1.0, 1.0, 1.0, 1.0]; +u0 = [1.0; 1.0] +; +prob = ODEProblem(fiip, u0, (0.0, 4.0), p); + +dp1 = Zygote.gradient(p -> sum(solve(prob, Tsit5(), u0 = u0, p = p, + sensealg = ForwardDiffSensitivity(), saveat = 0.1, + abstol = 1e-12, reltol = 1e-12)), p) +dp2 = Zygote.gradient(p -> sum(solve(prob, Tsit5(), u0 = u0, p = p, + sensealg = ForwardDiffSensitivity(convert_tspan = true), + saveat = 0.1, abstol = 1e-12, reltol = 1e-12)), p) +dp3 = Zygote.gradient(p -> sum(solve(prob, Tsit5(), u0 = u0, p = p, + sensealg = ForwardSensitivity(), saveat = 0.1, + abstol = 1e-12, reltol = 1e-12)), p) +dp4 = Zygote.gradient(p -> sum(solve(prob, Tsit5(), u0 = u0, p = p, saveat = 0.1, + abstol = 1e-12, reltol = 1e-12)), p) +dp5 = Zygote.gradient(p -> sum(solve(prob, Tsit5(), u0 = u0, p = p, saveat = 0.1, + abstol = 1e-12, reltol = 1e-12, + sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP()))), + p) +dp6 = Zygote.gradient(p -> sum(solve(prob, Tsit5(), u0 = u0, p = p, saveat = 0.1, + abstol = 1e-12, reltol = 1e-12, + sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP(true)))), + p) + +@test dp1[1] ≈ dp2[1] +@test dp1[1] ≈ dp3[1] +@test dp1[1] ≈ dp4[1] +@test dp1[1] ≈ dp5[1] +@test sum(dp5[1]) ≈ sum(dp6[1]) +@test all(dp6[1][1:3] .== 0) diff --git a/test/callback_reversediff.jl b/test/callback_reversediff.jl index a5558483b..a03d6813e 100644 --- a/test/callback_reversediff.jl +++ b/test/callback_reversediff.jl @@ -3,57 +3,58 @@ using Random Random.seed!(1234) -u0 = Float32[2.; 0.] +u0 = Float32[2.0; 0.0] datasize = 100 -tspan = (0.0f0,10.5f0) -dosetimes = [1.0,2.0,4.0,8.0] +tspan = (0.0f0, 10.5f0) +dosetimes = [1.0, 2.0, 4.0, 8.0] function affect!(integrator) - integrator.u = integrator.u.+1 + integrator.u = integrator.u .+ 1 end -cb_ = PresetTimeCallback(dosetimes,affect!,save_positions=(false,false)) -function trueODEfunc(du,u,p,t) +cb_ = PresetTimeCallback(dosetimes, affect!, save_positions = (false, false)) +function trueODEfunc(du, u, p, t) du .= -u end -t = range(tspan[1],tspan[2],length=datasize) +t = range(tspan[1], tspan[2], length = datasize) -prob = ODEProblem(trueODEfunc,u0,tspan) -ode_data = Array(solve(prob,Tsit5(),callback=cb_,saveat=t)) -dudt2 = Chain(Dense(2,50,tanh), - Dense(50,2)) -p,re = Flux.destructure(dudt2) # use this p as the initial condition! +prob = ODEProblem(trueODEfunc, u0, tspan) +ode_data = Array(solve(prob, Tsit5(), callback = cb_, saveat = t)) +dudt2 = Chain(Dense(2, 50, tanh), + Dense(50, 2)) +p, re = Flux.destructure(dudt2) # use this p as the initial condition! -function dudt(du,u,p,t) +function dudt(du, u, p, t) du[1:2] .= -u[1:2] du[3:end] .= re(p)(u[1:2]) #re(p)(u[3:end]) end -z0 = Float32[u0;u0] -prob = ODEProblem(dudt,z0,tspan) +z0 = Float32[u0; u0] +prob = ODEProblem(dudt, z0, tspan) affect!(integrator) = integrator.u[1:2] .= integrator.u[3:end] -cb = PresetTimeCallback(dosetimes,affect!,save_positions=(false,false)) +cb = PresetTimeCallback(dosetimes, affect!, save_positions = (false, false)) function predict_n_ode() - _prob = remake(prob,p=p) - Array(solve(_prob,Tsit5(),u0=z0,p=p,callback=cb,saveat=t,sensealg=ReverseDiffAdjoint()))[1:2,:] + _prob = remake(prob, p = p) + Array(solve(_prob, Tsit5(), u0 = z0, p = p, callback = cb, saveat = t, + sensealg = ReverseDiffAdjoint()))[1:2, :] #Array(solve(prob,Tsit5(),u0=z0,p=p,saveat=t))[1:2,:] end function loss_n_ode() pred = predict_n_ode() - loss = sum(abs2,ode_data .- pred) + loss = sum(abs2, ode_data .- pred) loss end loss_n_ode() # n_ode.p stores the initial parameters of the neural ODE -cba = function (;doplot=false) #callback function to observe training - pred = predict_n_ode() - display(sum(abs2,ode_data .- pred)) - # plot current prediction against data - #pl = scatter(t,ode_data[1,:],label="data") - #scatter!(pl,t,pred[1,:],label="prediction") - #display(plot(pl)) - return false +cba = function (; doplot = false) #callback function to observe training + pred = predict_n_ode() + display(sum(abs2, ode_data .- pred)) + # plot current prediction against data + #pl = scatter(t,ode_data[1,:],label="data") + #scatter!(pl,t,pred[1,:],label="prediction") + #display(plot(pl)) + return false end cba() diff --git a/test/callbacks/SDE_callbacks.jl b/test/callbacks/SDE_callbacks.jl index 7fa3ec4fe..35aebc84f 100644 --- a/test/callbacks/SDE_callbacks.jl +++ b/test/callbacks/SDE_callbacks.jl @@ -6,44 +6,45 @@ reltol = 1e-12 savingtimes = 0.5 function test_SDE_callbacks() - function dt!(du, u, p, t) - x, y = u - α, β, δ, γ = p - du[1] = dx = α * x - β * x * y - du[2] = dy = -δ * y + γ * x * y - end - - function dW!(du, u, p, t) - du[1] = 0.1u[1] - du[2] = 0.1u[2] - end - - u0 = [1.0, 1.0] - tspan = (0.0, 10.0) - p = [2.2, 1.0, 2.0, 0.4] - prob_sde = SDEProblem(dt!, dW!, u0, tspan, p) - - condition(u, t, integrator) = integrator.t > 9.0 #some condition - function affect!(integrator) - #println("Callback") #some callback - end - cb = DiscreteCallback(condition, affect!, save_positions=(false, false)) - - function predict_sde(p) - return Array(solve(prob_sde, EM(), p=p, saveat=savingtimes, sensealg=ForwardDiffSensitivity(), dt=0.001, callback=cb)) - end - - loss_sde(p) = sum(abs2, x - 1 for x in predict_sde(p)) - - loss_sde(p) - @time dp = gradient(p) do p + function dt!(du, u, p, t) + x, y = u + α, β, δ, γ = p + du[1] = dx = α * x - β * x * y + du[2] = dy = -δ * y + γ * x * y + end + + function dW!(du, u, p, t) + du[1] = 0.1u[1] + du[2] = 0.1u[2] + end + + u0 = [1.0, 1.0] + tspan = (0.0, 10.0) + p = [2.2, 1.0, 2.0, 0.4] + prob_sde = SDEProblem(dt!, dW!, u0, tspan, p) + + condition(u, t, integrator) = integrator.t > 9.0 #some condition + function affect!(integrator) + #println("Callback") #some callback + end + cb = DiscreteCallback(condition, affect!, save_positions = (false, false)) + + function predict_sde(p) + return Array(solve(prob_sde, EM(), p = p, saveat = savingtimes, + sensealg = ForwardDiffSensitivity(), dt = 0.001, callback = cb)) + end + + loss_sde(p) = sum(abs2, x - 1 for x in predict_sde(p)) + loss_sde(p) - end + @time dp = gradient(p) do p + loss_sde(p) + end - @test !iszero(dp[1]) + @test !iszero(dp[1]) end @testset "SDEs" begin - println("SDEs") - test_SDE_callbacks() -end \ No newline at end of file + println("SDEs") + test_SDE_callbacks() +end diff --git a/test/callbacks/continuous_callbacks.jl b/test/callbacks/continuous_callbacks.jl index ac98b2f30..771ff8c19 100644 --- a/test/callbacks/continuous_callbacks.jl +++ b/test/callbacks/continuous_callbacks.jl @@ -5,208 +5,243 @@ abstol = 1e-12 reltol = 1e-12 savingtimes = 0.5 -function test_continuous_callback(cb, g, dg!; only_backsolve=false) - function fiip(du, u, p, t) - du[1] = u[2] - du[2] = -p[1] - end - function foop(u, p, t) - dx = u[2] - dy = -p[1] - [dx, dy] - end - - u0 = [5.0, 0.0] - tspan = (0.0, 2.5) - p = [9.8, 0.8] - - prob = ODEProblem(fiip, u0, tspan, p) - proboop = ODEProblem(fiip, u0, tspan, p) - - sol1 = solve(prob, Tsit5(), u0=u0, p=p, callback=cb, abstol=abstol, reltol=reltol, saveat=savingtimes) - sol2 = solve(prob, Tsit5(), u0=u0, p=p, abstol=abstol, reltol=reltol, saveat=savingtimes) - - if cb.save_positions == [1, 1] - @test length(sol1.t) != length(sol2.t) - else - @test length(sol1.t) == length(sol2.t) - end - - du01, dp1 = @time Zygote.gradient( - (u0, p) -> g(solve(prob, Tsit5(), u0=u0, p=p, callback=cb, abstol=abstol, reltol=reltol, saveat=savingtimes, sensealg=BacksolveAdjoint())), - u0, p) - - du01b, dp1b = Zygote.gradient( - (u0, p) -> g(solve(proboop, Tsit5(), u0=u0, p=p, callback=cb, abstol=abstol, reltol=reltol, saveat=savingtimes, sensealg=BacksolveAdjoint())), - u0, p) - - du01c, dp1c = Zygote.gradient( - (u0, p) -> g(solve(proboop, Tsit5(), u0=u0, p=p, callback=cb, abstol=abstol, reltol=reltol, saveat=savingtimes, sensealg=BacksolveAdjoint(checkpointing=false))), - u0, p) - - if !only_backsolve - @test_broken du02, dp2 = @time Zygote.gradient( - (u0, p) -> g(solve(prob, Tsit5(), u0=u0, p=p, callback=cb, abstol=abstol, reltol=reltol, saveat=savingtimes, sensealg=ReverseDiffAdjoint())), u0, p) - - du03, dp3 = @time Zygote.gradient( - (u0, p) -> g(solve(prob, Tsit5(), u0=u0, p=p, callback=cb, abstol=abstol, reltol=reltol, saveat=savingtimes, sensealg=InterpolatingAdjoint(checkpointing=true))), - u0, p) - - du03c, dp3c = Zygote.gradient( - (u0, p) -> g(solve(prob, Tsit5(), u0=u0, p=p, callback=cb, abstol=abstol, reltol=reltol, saveat=savingtimes, sensealg=InterpolatingAdjoint(checkpointing=false))), - u0, p) - - du04, dp4 = @time Zygote.gradient( - (u0, p) -> g(solve(prob, Tsit5(), u0=u0, p=p, callback=cb, abstol=abstol, reltol=reltol, saveat=savingtimes, sensealg=QuadratureAdjoint())), - u0, p) - end - dstuff = @time ForwardDiff.gradient( - (θ) -> g(solve(prob, Tsit5(), u0=θ[1:2], p=θ[3:4], callback=cb, abstol=abstol, reltol=reltol, saveat=savingtimes)), - [u0; p]) - - @info dstuff - - @test du01 ≈ dstuff[1:2] - @test dp1 ≈ dstuff[3:4] - @test du01b ≈ dstuff[1:2] - @test dp1b ≈ dstuff[3:4] - @test du01c ≈ dstuff[1:2] - @test dp1c ≈ dstuff[3:4] - if !only_backsolve - @test_broken du01 ≈ du02 - @test du01 ≈ du03 rtol = 1e-7 - @test du01 ≈ du03c rtol = 1e-7 - @test du03 ≈ du03c - @test du01 ≈ du04 - @test_broken dp1 ≈ dp2 - @test dp1 ≈ dp3 - @test dp1 ≈ dp3c - @test dp3 ≈ dp3c - @test dp1 ≈ dp4 rtol = 1e-7 - - @test_broken du02 ≈ dstuff[1:2] - @test_broken dp2 ≈ dstuff[3:4] - end - - cb2 = SciMLSensitivity.track_callbacks(CallbackSet(cb), prob.tspan[1], prob.u0, prob.p, BacksolveAdjoint(autojacvec=ReverseDiffVJP())) - sol_track = solve(prob, Tsit5(), u0=u0, p=p, callback=cb2, abstol=abstol, reltol=reltol, saveat=savingtimes) - - adj_prob = ODEAdjointProblem(sol_track, BacksolveAdjoint(autojacvec=ReverseDiffVJP()), sol_track.t, dg!, - callback=cb2, - abstol=abstol, reltol=reltol) - adj_sol = solve(adj_prob, Tsit5(), abstol=abstol, reltol=reltol) - @test du01 ≈ adj_sol[1:2, end] - @test dp1 ≈ adj_sol[3:4, end] - -end - -println("Continuous Callbacks") -@testset "Continuous callbacks" begin - @testset "simple loss function bouncing ball" begin - g(sol) = sum(sol) - function dg!(out, u, p, t, i) - (out .= 1) - end - - @testset "callbacks with no effect" begin - condition(u, t, integrator) = u[1] # Event when event_f(u,t) == 0 - affect!(integrator) = (integrator.u[2] += 0) - cb = ContinuousCallback(condition, affect!, save_positions=(false, false)) - test_continuous_callback(cb, g, dg!) - end - @testset "callbacks with no effect except saving the state" begin - condition(u, t, integrator) = u[1] - affect!(integrator) = (integrator.u[2] += 0) - cb = ContinuousCallback(condition, affect!, save_positions=(true, true)) - test_continuous_callback(cb, g, dg!) - end - @testset "+= callback" begin - condition(u, t, integrator) = u[1] - affect!(integrator) = (integrator.u[2] += 50.0) - cb = ContinuousCallback(condition, affect!, save_positions=(true, true)) - test_continuous_callback(cb, g, dg!) - end - @testset "= callback with parameter dependence and save" begin - condition(u, t, integrator) = u[1] - affect!(integrator) = (integrator.u[2] = -integrator.p[2] * integrator.u[2]) - cb = ContinuousCallback(condition, affect!, save_positions=(true, true)) - test_continuous_callback(cb, g, dg!) - end - @testset "= callback with parameter dependence but without save" begin - condition(u, t, integrator) = u[1] - affect!(integrator) = (integrator.u[2] = -integrator.p[2] * integrator.u[2]) - cb = ContinuousCallback(condition, affect!, save_positions=(false, false)) - test_continuous_callback(cb, g, dg!; only_backsolve=true) - end - @testset "= callback with non-linear affect" begin - condition(u, t, integrator) = u[1] - affect!(integrator) = (integrator.u[2] = integrator.u[2]^2) - cb = ContinuousCallback(condition, affect!, save_positions=(true, true)) - test_continuous_callback(cb, g, dg!) - end - @testset "= callback with terminate" begin - condition(u, t, integrator) = u[1] - affect!(integrator) = (integrator.u[2] = -integrator.p[2] * integrator.u[2]; terminate!(integrator)) - cb = ContinuousCallback(condition, affect!, save_positions=(true, true)) - test_continuous_callback(cb, g, dg!; only_backsolve=true) - end - end - @testset "MSE loss function bouncing-ball like" begin - g(u) = sum((1.0 .- u) .^ 2) ./ 2 - dg!(out, u, p, t, i) = (out .= -1.0 .+ u) - condition(u, t, integrator) = u[1] - @testset "callback with non-linear affect" begin - function affect!(integrator) - integrator.u[1] += 3.0 - integrator.u[2] = integrator.u[2]^2 - end - cb = ContinuousCallback(condition, affect!, save_positions=(true, true)) - test_continuous_callback(cb, g, dg!) - end - @testset "callback with non-linear affect and terminate" begin - function affect!(integrator) - integrator.u[1] += 3.0 - integrator.u[2] = integrator.u[2]^2 - terminate!(integrator) - end - cb = ContinuousCallback(condition, affect!, save_positions=(true, true)) - test_continuous_callback(cb, g, dg!; only_backsolve=true) - end - end - @testset "MSE loss function free particle" begin - g(u) = sum((1.0 .- u) .^ 2) ./ 2 +function test_continuous_callback(cb, g, dg!; only_backsolve = false) function fiip(du, u, p, t) - du[1] = u[2] - du[2] = 0 + du[1] = u[2] + du[2] = -p[1] end function foop(u, p, t) - dx = u[2] - dy = 0 - [dx, dy] + dx = u[2] + dy = -p[1] + [dx, dy] end - u0 = [5.0, -1.0] - p = [0.0, 0.0] - tspan = (0.0, 2.0) + u0 = [5.0, 0.0] + tspan = (0.0, 2.5) + p = [9.8, 0.8] prob = ODEProblem(fiip, u0, tspan, p) proboop = ODEProblem(fiip, u0, tspan, p) - condition(u, t, integrator) = u[1] # Event when event_f(u,t) == 0 - affect!(integrator) = (integrator.u[2] = -integrator.u[2]) - cb = ContinuousCallback(condition, affect!) + sol1 = solve(prob, Tsit5(), u0 = u0, p = p, callback = cb, abstol = abstol, + reltol = reltol, saveat = savingtimes) + sol2 = solve(prob, Tsit5(), u0 = u0, p = p, abstol = abstol, reltol = reltol, + saveat = savingtimes) - du01, dp1 = Zygote.gradient( - (u0, p) -> g(solve(prob, Tsit5(), u0=u0, p=p, callback=cb, abstol=abstol, reltol=reltol, saveat=savingtimes, sensealg=BacksolveAdjoint())), - u0, p) + if cb.save_positions == [1, 1] + @test length(sol1.t) != length(sol2.t) + else + @test length(sol1.t) == length(sol2.t) + end - dstuff = @time ForwardDiff.gradient( - (θ) -> g(solve(prob, Tsit5(), u0=θ[1:2], p=θ[3:4], callback=cb, abstol=abstol, reltol=reltol, saveat=savingtimes)), - [u0; p]) + du01, dp1 = @time Zygote.gradient((u0, p) -> g(solve(prob, Tsit5(), u0 = u0, p = p, + callback = cb, abstol = abstol, + reltol = reltol, + saveat = savingtimes, + sensealg = BacksolveAdjoint())), + u0, p) + + du01b, dp1b = Zygote.gradient((u0, p) -> g(solve(proboop, Tsit5(), u0 = u0, p = p, + callback = cb, abstol = abstol, + reltol = reltol, saveat = savingtimes, + sensealg = BacksolveAdjoint())), + u0, p) + + du01c, dp1c = Zygote.gradient((u0, p) -> g(solve(proboop, Tsit5(), u0 = u0, p = p, + callback = cb, abstol = abstol, + reltol = reltol, saveat = savingtimes, + sensealg = BacksolveAdjoint(checkpointing = false))), + u0, p) + + if !only_backsolve + @test_broken du02, dp2 = @time Zygote.gradient((u0, p) -> g(solve(prob, Tsit5(), + u0 = u0, p = p, + callback = cb, + abstol = abstol, + reltol = reltol, + saveat = savingtimes, + sensealg = ReverseDiffAdjoint())), + u0, p) + + du03, dp3 = @time Zygote.gradient((u0, p) -> g(solve(prob, Tsit5(), u0 = u0, p = p, + callback = cb, abstol = abstol, + reltol = reltol, + saveat = savingtimes, + sensealg = InterpolatingAdjoint(checkpointing = true))), + u0, p) + + du03c, dp3c = Zygote.gradient((u0, p) -> g(solve(prob, Tsit5(), u0 = u0, p = p, + callback = cb, abstol = abstol, + reltol = reltol, + saveat = savingtimes, + sensealg = InterpolatingAdjoint(checkpointing = false))), + u0, p) + + du04, dp4 = @time Zygote.gradient((u0, p) -> g(solve(prob, Tsit5(), u0 = u0, p = p, + callback = cb, abstol = abstol, + reltol = reltol, + saveat = savingtimes, + sensealg = QuadratureAdjoint())), + u0, p) + end + dstuff = @time ForwardDiff.gradient((θ) -> g(solve(prob, Tsit5(), u0 = θ[1:2], + p = θ[3:4], callback = cb, + abstol = abstol, reltol = reltol, + saveat = savingtimes)), + [u0; p]) @info dstuff @test du01 ≈ dstuff[1:2] @test dp1 ≈ dstuff[3:4] - end + @test du01b ≈ dstuff[1:2] + @test dp1b ≈ dstuff[3:4] + @test du01c ≈ dstuff[1:2] + @test dp1c ≈ dstuff[3:4] + if !only_backsolve + @test_broken du01 ≈ du02 + @test du01≈du03 rtol=1e-7 + @test du01≈du03c rtol=1e-7 + @test du03 ≈ du03c + @test du01 ≈ du04 + @test_broken dp1 ≈ dp2 + @test dp1 ≈ dp3 + @test dp1 ≈ dp3c + @test dp3 ≈ dp3c + @test dp1≈dp4 rtol=1e-7 + + @test_broken du02 ≈ dstuff[1:2] + @test_broken dp2 ≈ dstuff[3:4] + end + + cb2 = SciMLSensitivity.track_callbacks(CallbackSet(cb), prob.tspan[1], prob.u0, prob.p, + BacksolveAdjoint(autojacvec = ReverseDiffVJP())) + sol_track = solve(prob, Tsit5(), u0 = u0, p = p, callback = cb2, abstol = abstol, + reltol = reltol, saveat = savingtimes) + + adj_prob = ODEAdjointProblem(sol_track, BacksolveAdjoint(autojacvec = ReverseDiffVJP()), + sol_track.t, dg!, + callback = cb2, + abstol = abstol, reltol = reltol) + adj_sol = solve(adj_prob, Tsit5(), abstol = abstol, reltol = reltol) + @test du01 ≈ adj_sol[1:2, end] + @test dp1 ≈ adj_sol[3:4, end] +end + +println("Continuous Callbacks") +@testset "Continuous callbacks" begin + @testset "simple loss function bouncing ball" begin + g(sol) = sum(sol) + function dg!(out, u, p, t, i) + (out .= 1) + end + + @testset "callbacks with no effect" begin + condition(u, t, integrator) = u[1] # Event when event_f(u,t) == 0 + affect!(integrator) = (integrator.u[2] += 0) + cb = ContinuousCallback(condition, affect!, save_positions = (false, false)) + test_continuous_callback(cb, g, dg!) + end + @testset "callbacks with no effect except saving the state" begin + condition(u, t, integrator) = u[1] + affect!(integrator) = (integrator.u[2] += 0) + cb = ContinuousCallback(condition, affect!, save_positions = (true, true)) + test_continuous_callback(cb, g, dg!) + end + @testset "+= callback" begin + condition(u, t, integrator) = u[1] + affect!(integrator) = (integrator.u[2] += 50.0) + cb = ContinuousCallback(condition, affect!, save_positions = (true, true)) + test_continuous_callback(cb, g, dg!) + end + @testset "= callback with parameter dependence and save" begin + condition(u, t, integrator) = u[1] + affect!(integrator) = (integrator.u[2] = -integrator.p[2] * integrator.u[2]) + cb = ContinuousCallback(condition, affect!, save_positions = (true, true)) + test_continuous_callback(cb, g, dg!) + end + @testset "= callback with parameter dependence but without save" begin + condition(u, t, integrator) = u[1] + affect!(integrator) = (integrator.u[2] = -integrator.p[2] * integrator.u[2]) + cb = ContinuousCallback(condition, affect!, save_positions = (false, false)) + test_continuous_callback(cb, g, dg!; only_backsolve = true) + end + @testset "= callback with non-linear affect" begin + condition(u, t, integrator) = u[1] + affect!(integrator) = (integrator.u[2] = integrator.u[2]^2) + cb = ContinuousCallback(condition, affect!, save_positions = (true, true)) + test_continuous_callback(cb, g, dg!) + end + @testset "= callback with terminate" begin + condition(u, t, integrator) = u[1] + function affect!(integrator) + (integrator.u[2] = -integrator.p[2] * integrator.u[2]; terminate!(integrator)) + end + cb = ContinuousCallback(condition, affect!, save_positions = (true, true)) + test_continuous_callback(cb, g, dg!; only_backsolve = true) + end + end + @testset "MSE loss function bouncing-ball like" begin + g(u) = sum((1.0 .- u) .^ 2) ./ 2 + dg!(out, u, p, t, i) = (out .= -1.0 .+ u) + condition(u, t, integrator) = u[1] + @testset "callback with non-linear affect" begin + function affect!(integrator) + integrator.u[1] += 3.0 + integrator.u[2] = integrator.u[2]^2 + end + cb = ContinuousCallback(condition, affect!, save_positions = (true, true)) + test_continuous_callback(cb, g, dg!) + end + @testset "callback with non-linear affect and terminate" begin + function affect!(integrator) + integrator.u[1] += 3.0 + integrator.u[2] = integrator.u[2]^2 + terminate!(integrator) + end + cb = ContinuousCallback(condition, affect!, save_positions = (true, true)) + test_continuous_callback(cb, g, dg!; only_backsolve = true) + end + end + @testset "MSE loss function free particle" begin + g(u) = sum((1.0 .- u) .^ 2) ./ 2 + function fiip(du, u, p, t) + du[1] = u[2] + du[2] = 0 + end + function foop(u, p, t) + dx = u[2] + dy = 0 + [dx, dy] + end + + u0 = [5.0, -1.0] + p = [0.0, 0.0] + tspan = (0.0, 2.0) + + prob = ODEProblem(fiip, u0, tspan, p) + proboop = ODEProblem(fiip, u0, tspan, p) + + condition(u, t, integrator) = u[1] # Event when event_f(u,t) == 0 + affect!(integrator) = (integrator.u[2] = -integrator.u[2]) + cb = ContinuousCallback(condition, affect!) + + du01, dp1 = Zygote.gradient((u0, p) -> g(solve(prob, Tsit5(), u0 = u0, p = p, + callback = cb, abstol = abstol, + reltol = reltol, + saveat = savingtimes, + sensealg = BacksolveAdjoint())), + u0, p) + + dstuff = @time ForwardDiff.gradient((θ) -> g(solve(prob, Tsit5(), u0 = θ[1:2], + p = θ[3:4], callback = cb, + abstol = abstol, reltol = reltol, + saveat = savingtimes)), + [u0; p]) + + @info dstuff + + @test du01 ≈ dstuff[1:2] + @test dp1 ≈ dstuff[3:4] + end end diff --git a/test/callbacks/continuous_vs_discrete.jl b/test/callbacks/continuous_vs_discrete.jl index fdab41f13..693031660 100644 --- a/test/callbacks/continuous_vs_discrete.jl +++ b/test/callbacks/continuous_vs_discrete.jl @@ -6,156 +6,155 @@ reltol = 1e-12 savingtimes = 0.5 function test_continuous_wrt_discrete_callback() - # test the continuous callbacks wrt to the equivalent discrete callback - function f(du, u, p, t) - #Bouncing Ball - du[1] = u[2] - du[2] = -p[1] - end - - # no saving in Callbacks; prescribed vafter and vbefore; loss on the endpoint - - tstop = 3.1943828249997 - vbefore = -31.30495168499705 - vafter = 25.04396134799764 - - u0 = [50.0, 0.0] - tspan = (0.0, 5.0) - p = [9.8, 0.8] - - prob = ODEProblem(f, u0, tspan, p) - - function condition(u, t, integrator) # Event when event_f(u,t) == 0 - t - tstop - end - function affect!(integrator) - integrator.u[2] += vafter - vbefore - end - cb = ContinuousCallback(condition, affect!, save_positions=(false, false)) - - - condition2(u, t, integrator) = t == tstop - cb2 = DiscreteCallback(condition2, affect!, save_positions=(false, false)) - - - du01, dp1 = Zygote.gradient( - (u0, p) -> sum(solve(prob, Tsit5(), u0=u0, p=p, - callback=cb2, tstops=[tstop], - sensealg=BacksolveAdjoint(), - saveat=tspan[2], save_start=false)), u0, p) - - du02, dp2 = Zygote.gradient( - (u0, p) -> sum(solve(prob, Tsit5(), u0=u0, p=p, - callback=cb, - sensealg=BacksolveAdjoint(), - saveat=tspan[2], save_start=false)), u0, p) - - dstuff = ForwardDiff.gradient((θ) -> sum(solve(prob, Tsit5(), u0=θ[1:2], p=θ[3:4], - callback=cb, saveat=tspan[2], save_start=false)), [u0; p]) - - @info dstuff - @test du01 ≈ dstuff[1:2] - @test dp1 ≈ dstuff[3:4] - @test du02 ≈ dstuff[1:2] - @test dp2 ≈ dstuff[3:4] - - # no saving in Callbacks; prescribed vafter and vbefore; loss on the endpoint by slicing - du01, dp1 = Zygote.gradient( - (u0, p) -> sum(solve(prob, Tsit5(), u0=u0, p=p, - callback=cb2, tstops=[tstop], - sensealg=BacksolveAdjoint())[end]), u0, p) - - du02, dp2 = Zygote.gradient( - (u0, p) -> sum(solve(prob, Tsit5(), u0=u0, p=p, - callback=cb, - sensealg=BacksolveAdjoint())[end]), u0, p) - - dstuff = ForwardDiff.gradient((θ) -> sum(solve(prob, Tsit5(), u0=θ[1:2], p=θ[3:4], - callback=cb)[end]), [u0; p]) - - @info dstuff - @test du01 ≈ dstuff[1:2] - @test dp1 ≈ dstuff[3:4] - @test du02 ≈ dstuff[1:2] - @test dp2 ≈ dstuff[3:4] - - # with saving in Callbacks; prescribed vafter and vbefore; loss on the endpoint - cb = ContinuousCallback(condition, affect!, save_positions=(true, true)) - cb2 = DiscreteCallback(condition2, affect!, save_positions=(true, true)) - - du01, dp1 = Zygote.gradient( - (u0, p) -> sum(solve(prob, Tsit5(), u0=u0, p=p, - callback=cb2, tstops=[tstop], - sensealg=BacksolveAdjoint(), - saveat=tspan[2], save_start=false)), u0, p) - - du02, dp2 = Zygote.gradient( - (u0, p) -> sum(solve(prob, Tsit5(), u0=u0, p=p, - callback=cb, - sensealg=BacksolveAdjoint(), - saveat=tspan[2], save_start=false)), u0, p) - - dstuff = ForwardDiff.gradient((θ) -> sum(solve(prob, Tsit5(), u0=θ[1:2], p=θ[3:4], - callback=cb, saveat=tspan[2], save_start=false)), [u0; p]) - - @info dstuff - @test du01 ≈ dstuff[1:2] - @test dp1 ≈ dstuff[3:4] - @test du02 ≈ dstuff[1:2] - @test dp2 ≈ dstuff[3:4] - - # with saving in Callbacks; prescribed vafter and vbefore; loss on the endpoint by slicing - du01, dp1 = Zygote.gradient( - (u0, p) -> sum(solve(prob, Tsit5(), u0=u0, p=p, - callback=cb2, tstops=[tstop], - sensealg=BacksolveAdjoint())[end]), u0, p) - - du02, dp2 = Zygote.gradient( - (u0, p) -> sum(solve(prob, Tsit5(), u0=u0, p=p, - callback=cb, - sensealg=BacksolveAdjoint())[end]), u0, p) - - dstuff = ForwardDiff.gradient((θ) -> sum(solve(prob, Tsit5(), u0=θ[1:2], p=θ[3:4], - callback=cb)[end]), [u0; p]) - - @info dstuff - @test du01 ≈ dstuff[1:2] - @test dp1 ≈ dstuff[3:4] - @test du02 ≈ dstuff[1:2] - @test dp2 ≈ dstuff[3:4] - - # with saving in Callbacks; different affect function - function affect2!(integrator) - integrator.u[2] = -integrator.p[2] * integrator.u[2] - end - cb = ContinuousCallback(condition, affect2!, save_positions=(true, true)) - - cb2 = DiscreteCallback(condition2, affect2!, save_positions=(true, true)) - - du01, dp1 = Zygote.gradient( - (u0, p) -> sum(solve(prob, Tsit5(), u0=u0, p=p, - callback=cb2, tstops=[tstop], - sensealg=BacksolveAdjoint(), - saveat=tspan[2], save_start=false)), u0, p) - - du02, dp2 = Zygote.gradient( - (u0, p) -> sum(solve(prob, Tsit5(), u0=u0, p=p, - callback=cb, - sensealg=BacksolveAdjoint(), - saveat=tspan[2], save_start=false)), u0, p) - - dstuff = ForwardDiff.gradient((θ) -> sum(solve(prob, Tsit5(), u0=θ[1:2], p=θ[3:4], - callback=cb, saveat=tspan[2], save_start=false)), [u0; p]) - - @info dstuff - @test du01 ≈ dstuff[1:2] - @test dp1 ≈ dstuff[3:4] - @test du02 ≈ dstuff[1:2] - @test dp2 ≈ dstuff[3:4] - @test du01 ≈ du02 - @test dp1 ≈ dp2 + # test the continuous callbacks wrt to the equivalent discrete callback + function f(du, u, p, t) + #Bouncing Ball + du[1] = u[2] + du[2] = -p[1] + end + + # no saving in Callbacks; prescribed vafter and vbefore; loss on the endpoint + + tstop = 3.1943828249997 + vbefore = -31.30495168499705 + vafter = 25.04396134799764 + + u0 = [50.0, 0.0] + tspan = (0.0, 5.0) + p = [9.8, 0.8] + + prob = ODEProblem(f, u0, tspan, p) + + function condition(u, t, integrator) # Event when event_f(u,t) == 0 + t - tstop + end + function affect!(integrator) + integrator.u[2] += vafter - vbefore + end + cb = ContinuousCallback(condition, affect!, save_positions = (false, false)) + + condition2(u, t, integrator) = t == tstop + cb2 = DiscreteCallback(condition2, affect!, save_positions = (false, false)) + + du01, dp1 = Zygote.gradient((u0, p) -> sum(solve(prob, Tsit5(), u0 = u0, p = p, + callback = cb2, tstops = [tstop], + sensealg = BacksolveAdjoint(), + saveat = tspan[2], save_start = false)), + u0, p) + + du02, dp2 = Zygote.gradient((u0, p) -> sum(solve(prob, Tsit5(), u0 = u0, p = p, + callback = cb, + sensealg = BacksolveAdjoint(), + saveat = tspan[2], save_start = false)), + u0, p) + + dstuff = ForwardDiff.gradient((θ) -> sum(solve(prob, Tsit5(), u0 = θ[1:2], p = θ[3:4], + callback = cb, saveat = tspan[2], + save_start = false)), [u0; p]) + + @info dstuff + @test du01 ≈ dstuff[1:2] + @test dp1 ≈ dstuff[3:4] + @test du02 ≈ dstuff[1:2] + @test dp2 ≈ dstuff[3:4] + + # no saving in Callbacks; prescribed vafter and vbefore; loss on the endpoint by slicing + du01, dp1 = Zygote.gradient((u0, p) -> sum(solve(prob, Tsit5(), u0 = u0, p = p, + callback = cb2, tstops = [tstop], + sensealg = BacksolveAdjoint())[end]), + u0, p) + + du02, dp2 = Zygote.gradient((u0, p) -> sum(solve(prob, Tsit5(), u0 = u0, p = p, + callback = cb, + sensealg = BacksolveAdjoint())[end]), + u0, p) + + dstuff = ForwardDiff.gradient((θ) -> sum(solve(prob, Tsit5(), u0 = θ[1:2], p = θ[3:4], + callback = cb)[end]), [u0; p]) + + @info dstuff + @test du01 ≈ dstuff[1:2] + @test dp1 ≈ dstuff[3:4] + @test du02 ≈ dstuff[1:2] + @test dp2 ≈ dstuff[3:4] + + # with saving in Callbacks; prescribed vafter and vbefore; loss on the endpoint + cb = ContinuousCallback(condition, affect!, save_positions = (true, true)) + cb2 = DiscreteCallback(condition2, affect!, save_positions = (true, true)) + + du01, dp1 = Zygote.gradient((u0, p) -> sum(solve(prob, Tsit5(), u0 = u0, p = p, + callback = cb2, tstops = [tstop], + sensealg = BacksolveAdjoint(), + saveat = tspan[2], save_start = false)), + u0, p) + + du02, dp2 = Zygote.gradient((u0, p) -> sum(solve(prob, Tsit5(), u0 = u0, p = p, + callback = cb, + sensealg = BacksolveAdjoint(), + saveat = tspan[2], save_start = false)), + u0, p) + + dstuff = ForwardDiff.gradient((θ) -> sum(solve(prob, Tsit5(), u0 = θ[1:2], p = θ[3:4], + callback = cb, saveat = tspan[2], + save_start = false)), [u0; p]) + + @info dstuff + @test du01 ≈ dstuff[1:2] + @test dp1 ≈ dstuff[3:4] + @test du02 ≈ dstuff[1:2] + @test dp2 ≈ dstuff[3:4] + + # with saving in Callbacks; prescribed vafter and vbefore; loss on the endpoint by slicing + du01, dp1 = Zygote.gradient((u0, p) -> sum(solve(prob, Tsit5(), u0 = u0, p = p, + callback = cb2, tstops = [tstop], + sensealg = BacksolveAdjoint())[end]), + u0, p) + + du02, dp2 = Zygote.gradient((u0, p) -> sum(solve(prob, Tsit5(), u0 = u0, p = p, + callback = cb, + sensealg = BacksolveAdjoint())[end]), + u0, p) + + dstuff = ForwardDiff.gradient((θ) -> sum(solve(prob, Tsit5(), u0 = θ[1:2], p = θ[3:4], + callback = cb)[end]), [u0; p]) + + @info dstuff + @test du01 ≈ dstuff[1:2] + @test dp1 ≈ dstuff[3:4] + @test du02 ≈ dstuff[1:2] + @test dp2 ≈ dstuff[3:4] + + # with saving in Callbacks; different affect function + function affect2!(integrator) + integrator.u[2] = -integrator.p[2] * integrator.u[2] + end + cb = ContinuousCallback(condition, affect2!, save_positions = (true, true)) + + cb2 = DiscreteCallback(condition2, affect2!, save_positions = (true, true)) + + du01, dp1 = Zygote.gradient((u0, p) -> sum(solve(prob, Tsit5(), u0 = u0, p = p, + callback = cb2, tstops = [tstop], + sensealg = BacksolveAdjoint(), + saveat = tspan[2], save_start = false)), + u0, p) + + du02, dp2 = Zygote.gradient((u0, p) -> sum(solve(prob, Tsit5(), u0 = u0, p = p, + callback = cb, + sensealg = BacksolveAdjoint(), + saveat = tspan[2], save_start = false)), + u0, p) + + dstuff = ForwardDiff.gradient((θ) -> sum(solve(prob, Tsit5(), u0 = θ[1:2], p = θ[3:4], + callback = cb, saveat = tspan[2], + save_start = false)), [u0; p]) + + @info dstuff + @test du01 ≈ dstuff[1:2] + @test dp1 ≈ dstuff[3:4] + @test du02 ≈ dstuff[1:2] + @test dp2 ≈ dstuff[3:4] + @test du01 ≈ du02 + @test dp1 ≈ dp2 end -@testset "Compare continuous with discrete callbacks" begin - test_continuous_wrt_discrete_callback() -end \ No newline at end of file +@testset "Compare continuous with discrete callbacks" begin test_continuous_wrt_discrete_callback() end diff --git a/test/callbacks/discrete_callbacks.jl b/test/callbacks/discrete_callbacks.jl index 4b4ae55a9..cb5823208 100644 --- a/test/callbacks/discrete_callbacks.jl +++ b/test/callbacks/discrete_callbacks.jl @@ -5,243 +5,281 @@ abstol = 1e-12 reltol = 1e-12 savingtimes = 0.5 -function test_discrete_callback(cb, tstops, g, dg!, cboop=nothing, tprev=false) - function fiip(du, u, p, t) - du[1] = dx = p[1] * u[1] - p[2] * u[1] * u[2] - du[2] = dy = -p[3] * u[2] + p[4] * u[1] * u[2] - end - function foop(u, p, t) - dx = p[1] * u[1] - p[2] * u[1] * u[2] - dy = -p[3] * u[2] + p[4] * u[1] * u[2] - [dx, dy] - end - - p = [1.5, 1.0, 3.0, 1.0] - u0 = [1.0; 1.0] - - prob = ODEProblem(fiip, u0, (0.0, 10.0), p) - proboop = ODEProblem(foop, u0, (0.0, 10.0), p) - - sol1 = solve(prob, Tsit5(), u0=u0, p=p, callback=cb, tstops=tstops, abstol=abstol, reltol=reltol, saveat=savingtimes) - sol2 = solve(prob, Tsit5(), u0=u0, p=p, tstops=tstops, abstol=abstol, reltol=reltol, saveat=savingtimes) - - if cb.save_positions == [1, 1] - @test length(sol1.t) != length(sol2.t) - else - @test length(sol1.t) == length(sol2.t) - end - - du01, dp1 = Zygote.gradient( - (u0, p) -> g(solve(prob, Tsit5(), u0=u0, p=p, callback=cb, tstops=tstops, abstol=abstol, reltol=reltol, saveat=savingtimes, sensealg=BacksolveAdjoint())), - u0, p) - - du01b, dp1b = Zygote.gradient( - (u0, p) -> g(solve(proboop, Tsit5(), u0=u0, p=p, callback=cb, tstops=tstops, abstol=abstol, reltol=reltol, saveat=savingtimes, sensealg=BacksolveAdjoint())), - u0, p) - - du01c, dp1c = Zygote.gradient( - (u0, p) -> g(solve(proboop, Tsit5(), u0=u0, p=p, callback=cb, tstops=tstops, abstol=abstol, reltol=reltol, saveat=savingtimes, sensealg=BacksolveAdjoint(checkpointing=false))), - u0, p) - - if cboop === nothing - du02, dp2 = Zygote.gradient( - (u0, p) -> g(solve(prob, Tsit5(), u0=u0, p=p, callback=cb, tstops=tstops, abstol=abstol, reltol=reltol, saveat=savingtimes, sensealg=ReverseDiffAdjoint())), u0, p) - else - du02, dp2 = Zygote.gradient( - (u0, p) -> g(solve(prob, Tsit5(), u0=u0, p=p, callback=cboop, tstops=tstops, abstol=abstol, reltol=reltol, saveat=savingtimes, sensealg=ReverseDiffAdjoint())), u0, p) - end - - du03, dp3 = Zygote.gradient( - (u0, p) -> g(solve(prob, Tsit5(), u0=u0, p=p, callback=cb, tstops=tstops, abstol=abstol, reltol=reltol, saveat=savingtimes, sensealg=InterpolatingAdjoint(checkpointing=true))), - u0, p) - - du03c, dp3c = Zygote.gradient( - (u0, p) -> g(solve(prob, Tsit5(), u0=u0, p=p, callback=cb, tstops=tstops, abstol=abstol, reltol=reltol, saveat=savingtimes, sensealg=InterpolatingAdjoint(checkpointing=false))), - u0, p) - - du04, dp4 = Zygote.gradient( - (u0, p) -> g(solve(prob, Tsit5(), u0=u0, p=p, callback=cb, tstops=tstops, abstol=abstol, reltol=reltol, saveat=savingtimes, sensealg=QuadratureAdjoint())), - u0, p) - - dstuff = ForwardDiff.gradient( - (θ) -> g(solve(prob, Tsit5(), u0=θ[1:2], p=θ[3:6], callback=cb, tstops=tstops, abstol=abstol, reltol=reltol, saveat=savingtimes)), - [u0; p]) - - @info dstuff - - # tests wrt discrete sensitivities - if tprev - # tprev depends on stepping behaviour of integrator. Thus sensitivities are necessarily (slightly) different. - @test du02 ≈ dstuff[1:2] rtol = 1e-3 - @test dp2 ≈ dstuff[3:6] rtol = 1e-3 - @test du01 ≈ dstuff[1:2] rtol = 1e-3 - @test dp1 ≈ dstuff[3:6] rtol = 1e-3 - @test du01 ≈ du02 rtol = 1e-3 - @test dp1 ≈ dp2 rtol = 1e-3 - else - @test du02 ≈ dstuff[1:2] - @test dp2 ≈ dstuff[3:6] - @test du01 ≈ dstuff[1:2] - @test dp1 ≈ dstuff[3:6] - @test du01 ≈ du02 - @test dp1 ≈ dp2 - end - - # tests wrt continuous sensitivities - @test du01b ≈ du01 - @test dp1b ≈ dp1 - @test du01c ≈ du01 - @test dp1c ≈ dp1 - @test du01 ≈ du03 rtol = 1e-7 - @test du01 ≈ du03c rtol = 1e-7 - @test du03 ≈ du03c - @test du01 ≈ du04 - @test dp1 ≈ dp3 - @test dp1 ≈ dp3c - @test dp1 ≈ dp4 rtol = 1e-7 - - cb2 = SciMLSensitivity.track_callbacks(CallbackSet(cb), prob.tspan[1], prob.u0, prob.p, BacksolveAdjoint(autojacvec=ReverseDiffVJP())) - sol_track = solve(prob, Tsit5(), u0=u0, p=p, callback=cb2, tstops=tstops, abstol=abstol, reltol=reltol, saveat=savingtimes) - #cb_adj = SciMLSensitivity.setup_reverse_callbacks(cb2,BacksolveAdjoint()) - - adj_prob = ODEAdjointProblem(sol_track, BacksolveAdjoint(autojacvec=ReverseDiffVJP()), sol_track.t, dg!, - callback=cb2, - abstol=abstol, reltol=reltol) - adj_sol = solve(adj_prob, Tsit5(), abstol=abstol, reltol=reltol) - @test du01 ≈ adj_sol[1:2, end] - @test dp1 ≈ adj_sol[3:6, end] +function test_discrete_callback(cb, tstops, g, dg!, cboop = nothing, tprev = false) + function fiip(du, u, p, t) + du[1] = dx = p[1] * u[1] - p[2] * u[1] * u[2] + du[2] = dy = -p[3] * u[2] + p[4] * u[1] * u[2] + end + function foop(u, p, t) + dx = p[1] * u[1] - p[2] * u[1] * u[2] + dy = -p[3] * u[2] + p[4] * u[1] * u[2] + [dx, dy] + end + + p = [1.5, 1.0, 3.0, 1.0] + u0 = [1.0; 1.0] + + prob = ODEProblem(fiip, u0, (0.0, 10.0), p) + proboop = ODEProblem(foop, u0, (0.0, 10.0), p) + + sol1 = solve(prob, Tsit5(), u0 = u0, p = p, callback = cb, tstops = tstops, + abstol = abstol, reltol = reltol, saveat = savingtimes) + sol2 = solve(prob, Tsit5(), u0 = u0, p = p, tstops = tstops, abstol = abstol, + reltol = reltol, saveat = savingtimes) + + if cb.save_positions == [1, 1] + @test length(sol1.t) != length(sol2.t) + else + @test length(sol1.t) == length(sol2.t) + end + + du01, dp1 = Zygote.gradient((u0, p) -> g(solve(prob, Tsit5(), u0 = u0, p = p, + callback = cb, tstops = tstops, + abstol = abstol, reltol = reltol, + saveat = savingtimes, + sensealg = BacksolveAdjoint())), + u0, p) + + du01b, dp1b = Zygote.gradient((u0, p) -> g(solve(proboop, Tsit5(), u0 = u0, p = p, + callback = cb, tstops = tstops, + abstol = abstol, reltol = reltol, + saveat = savingtimes, + sensealg = BacksolveAdjoint())), + u0, p) + + du01c, dp1c = Zygote.gradient((u0, p) -> g(solve(proboop, Tsit5(), u0 = u0, p = p, + callback = cb, tstops = tstops, + abstol = abstol, reltol = reltol, + saveat = savingtimes, + sensealg = BacksolveAdjoint(checkpointing = false))), + u0, p) + + if cboop === nothing + du02, dp2 = Zygote.gradient((u0, p) -> g(solve(prob, Tsit5(), u0 = u0, p = p, + callback = cb, tstops = tstops, + abstol = abstol, reltol = reltol, + saveat = savingtimes, + sensealg = ReverseDiffAdjoint())), + u0, p) + else + du02, dp2 = Zygote.gradient((u0, p) -> g(solve(prob, Tsit5(), u0 = u0, p = p, + callback = cboop, tstops = tstops, + abstol = abstol, reltol = reltol, + saveat = savingtimes, + sensealg = ReverseDiffAdjoint())), + u0, p) + end + + du03, dp3 = Zygote.gradient((u0, p) -> g(solve(prob, Tsit5(), u0 = u0, p = p, + callback = cb, tstops = tstops, + abstol = abstol, reltol = reltol, + saveat = savingtimes, + sensealg = InterpolatingAdjoint(checkpointing = true))), + u0, p) + + du03c, dp3c = Zygote.gradient((u0, p) -> g(solve(prob, Tsit5(), u0 = u0, p = p, + callback = cb, tstops = tstops, + abstol = abstol, reltol = reltol, + saveat = savingtimes, + sensealg = InterpolatingAdjoint(checkpointing = false))), + u0, p) + + du04, dp4 = Zygote.gradient((u0, p) -> g(solve(prob, Tsit5(), u0 = u0, p = p, + callback = cb, tstops = tstops, + abstol = abstol, reltol = reltol, + saveat = savingtimes, + sensealg = QuadratureAdjoint())), + u0, p) + + dstuff = ForwardDiff.gradient((θ) -> g(solve(prob, Tsit5(), u0 = θ[1:2], p = θ[3:6], + callback = cb, tstops = tstops, + abstol = abstol, reltol = reltol, + saveat = savingtimes)), + [u0; p]) + + @info dstuff + # tests wrt discrete sensitivities + if tprev + # tprev depends on stepping behaviour of integrator. Thus sensitivities are necessarily (slightly) different. + @test du02≈dstuff[1:2] rtol=1e-3 + @test dp2≈dstuff[3:6] rtol=1e-3 + @test du01≈dstuff[1:2] rtol=1e-3 + @test dp1≈dstuff[3:6] rtol=1e-3 + @test du01≈du02 rtol=1e-3 + @test dp1≈dp2 rtol=1e-3 + else + @test du02 ≈ dstuff[1:2] + @test dp2 ≈ dstuff[3:6] + @test du01 ≈ dstuff[1:2] + @test dp1 ≈ dstuff[3:6] + @test du01 ≈ du02 + @test dp1 ≈ dp2 + end + + # tests wrt continuous sensitivities + @test du01b ≈ du01 + @test dp1b ≈ dp1 + @test du01c ≈ du01 + @test dp1c ≈ dp1 + @test du01≈du03 rtol=1e-7 + @test du01≈du03c rtol=1e-7 + @test du03 ≈ du03c + @test du01 ≈ du04 + @test dp1 ≈ dp3 + @test dp1 ≈ dp3c + @test dp1≈dp4 rtol=1e-7 + + cb2 = SciMLSensitivity.track_callbacks(CallbackSet(cb), prob.tspan[1], prob.u0, prob.p, + BacksolveAdjoint(autojacvec = ReverseDiffVJP())) + sol_track = solve(prob, Tsit5(), u0 = u0, p = p, callback = cb2, tstops = tstops, + abstol = abstol, reltol = reltol, saveat = savingtimes) + #cb_adj = SciMLSensitivity.setup_reverse_callbacks(cb2,BacksolveAdjoint()) + + adj_prob = ODEAdjointProblem(sol_track, BacksolveAdjoint(autojacvec = ReverseDiffVJP()), + sol_track.t, dg!, + callback = cb2, + abstol = abstol, reltol = reltol) + adj_sol = solve(adj_prob, Tsit5(), abstol = abstol, reltol = reltol) + @test du01 ≈ adj_sol[1:2, end] + @test dp1 ≈ adj_sol[3:6, end] end -@testset "Discrete callbacks" begin - @testset "ODEs" begin +@testset "Discrete callbacks" begin @testset "ODEs" begin println("ODEs") @testset "simple loss function" begin - g(sol) = sum(sol) - function dg!(out, u, p, t, i) - (out .= 1) - end - @testset "callbacks with no effect" begin - condition(u, t, integrator) = t == 5 - affect!(integrator) = integrator.u[1] += 0.0 - cb = DiscreteCallback(condition, affect!, save_positions=(false, false)) - tstops = [5.0] - test_discrete_callback(cb, tstops, g, dg!) - end - @testset "callbacks with no effect except saving the state" begin - condition(u, t, integrator) = t == 5 - affect!(integrator) = integrator.u[1] += 0.0 - cb = DiscreteCallback(condition, affect!) - tstops = [5.0] - test_discrete_callback(cb, tstops, g, dg!) - end - @testset "callback at single time point" begin - condition(u, t, integrator) = t == 5 - affect!(integrator) = integrator.u[1] += 2.0 - cb = DiscreteCallback(condition, affect!) - tstops = [5.0] - test_discrete_callback(cb, tstops, g, dg!) - end - @testset "callback at multiple time points" begin - affecttimes = [2.03, 4.0, 8.0] - condition(u, t, integrator) = t ∈ affecttimes - affect!(integrator) = integrator.u[1] += 2.0 - cb = DiscreteCallback(condition, affect!) - test_discrete_callback(cb, affecttimes, g, dg!) - end - @testset "state-dependent += callback at single time point" begin - condition(u, t, integrator) = t == 5 - affect!(integrator) = (integrator.u .+= integrator.p[2] / 8 * sin.(integrator.u)) - cb = DiscreteCallback(condition, affect!) - tstops = [5.0] - test_discrete_callback(cb, tstops, g, dg!) - end - @testset "other callback at single time point" begin - condition(u, t, integrator) = t == 5 - affect!(integrator) = (integrator.u[1] = 2.0; @show "triggered!") - cb = DiscreteCallback(condition, affect!) - tstops = [5.0] - test_discrete_callback(cb, tstops, g, dg!) - end - @testset "parameter changing callback at single time point" begin - condition(u, t, integrator) = t == 5.1 - affect!(integrator) = (integrator.p .= 2 * integrator.p .- 0.5) - affect(integrator) = (integrator.p = 2 * integrator.p .- 0.5) - cb = DiscreteCallback(condition, affect!) - cboop = DiscreteCallback(condition, affect) - cb = DiscreteCallback(condition, affect!) - tstops = [5.1] - test_discrete_callback(cb, tstops, g, dg!, cboop) - end - @testset "tprev dependent callback" begin - condition(u, t, integrator) = t == 5 - affect!(integrator) = (@show integrator.tprev; integrator.u[1] += integrator.t - integrator.tprev) - cb = DiscreteCallback(condition, affect!) - tstops = [4.999, 5.0] - test_discrete_callback(cb, tstops, g, dg!, nothing, true) - end + g(sol) = sum(sol) + function dg!(out, u, p, t, i) + (out .= 1) + end + @testset "callbacks with no effect" begin + condition(u, t, integrator) = t == 5 + affect!(integrator) = integrator.u[1] += 0.0 + cb = DiscreteCallback(condition, affect!, save_positions = (false, false)) + tstops = [5.0] + test_discrete_callback(cb, tstops, g, dg!) + end + @testset "callbacks with no effect except saving the state" begin + condition(u, t, integrator) = t == 5 + affect!(integrator) = integrator.u[1] += 0.0 + cb = DiscreteCallback(condition, affect!) + tstops = [5.0] + test_discrete_callback(cb, tstops, g, dg!) + end + @testset "callback at single time point" begin + condition(u, t, integrator) = t == 5 + affect!(integrator) = integrator.u[1] += 2.0 + cb = DiscreteCallback(condition, affect!) + tstops = [5.0] + test_discrete_callback(cb, tstops, g, dg!) + end + @testset "callback at multiple time points" begin + affecttimes = [2.03, 4.0, 8.0] + condition(u, t, integrator) = t ∈ affecttimes + affect!(integrator) = integrator.u[1] += 2.0 + cb = DiscreteCallback(condition, affect!) + test_discrete_callback(cb, affecttimes, g, dg!) + end + @testset "state-dependent += callback at single time point" begin + condition(u, t, integrator) = t == 5 + function affect!(integrator) + (integrator.u .+= integrator.p[2] / 8 * sin.(integrator.u)) + end + cb = DiscreteCallback(condition, affect!) + tstops = [5.0] + test_discrete_callback(cb, tstops, g, dg!) + end + @testset "other callback at single time point" begin + condition(u, t, integrator) = t == 5 + affect!(integrator) = (integrator.u[1] = 2.0; @show "triggered!") + cb = DiscreteCallback(condition, affect!) + tstops = [5.0] + test_discrete_callback(cb, tstops, g, dg!) + end + @testset "parameter changing callback at single time point" begin + condition(u, t, integrator) = t == 5.1 + affect!(integrator) = (integrator.p .= 2 * integrator.p .- 0.5) + affect(integrator) = (integrator.p = 2 * integrator.p .- 0.5) + cb = DiscreteCallback(condition, affect!) + cboop = DiscreteCallback(condition, affect) + cb = DiscreteCallback(condition, affect!) + tstops = [5.1] + test_discrete_callback(cb, tstops, g, dg!, cboop) + end + @testset "tprev dependent callback" begin + condition(u, t, integrator) = t == 5 + function affect!(integrator) + (@show integrator.tprev; integrator.u[1] += integrator.t - integrator.tprev) + end + cb = DiscreteCallback(condition, affect!) + tstops = [4.999, 5.0] + test_discrete_callback(cb, tstops, g, dg!, nothing, true) + end end @testset "MSE loss function" begin - g(u) = sum((1.0 .- u) .^ 2) ./ 2 - dg!(out, u, p, t, i) = (out .= -1.0 .+ u) - @testset "callbacks with no effect" begin - condition(u, t, integrator) = t == 5 - affect!(integrator) = integrator.u[1] += 0.0 - cb = DiscreteCallback(condition, affect!, save_positions=(false, false)) - tstops = [5.0] - test_discrete_callback(cb, tstops, g, dg!) - end - @testset "callbacks with no effect except saving the state" begin - condition(u, t, integrator) = t == 5 - affect!(integrator) = integrator.u[1] += 0.0 - cb = DiscreteCallback(condition, affect!) - tstops = [5.0] - test_discrete_callback(cb, tstops, g, dg!) - end - @testset "callback at single time point" begin - condition(u, t, integrator) = t == 5 - affect!(integrator) = integrator.u[1] += 2.0 - cb = DiscreteCallback(condition, affect!) - tstops = [5.0] - test_discrete_callback(cb, tstops, g, dg!) - end - @testset "callback at multiple time points" begin - affecttimes = [2.03, 4.0, 8.0] - condition(u, t, integrator) = t ∈ affecttimes - affect!(integrator) = integrator.u[1] += 2.0 - cb = DiscreteCallback(condition, affect!) - test_discrete_callback(cb, affecttimes, g, dg!) - end - @testset "state-dependent += callback at single time point" begin - condition(u, t, integrator) = t == 5 - affect!(integrator) = (integrator.u .+= integrator.p[2] / 8 * sin.(integrator.u)) - cb = DiscreteCallback(condition, affect!) - tstops = [5.0] - test_discrete_callback(cb, tstops, g, dg!) - end - @testset "other callback at single time point" begin - condition(u, t, integrator) = t == 5 - affect!(integrator) = (integrator.u[1] = 2.0; @show "triggered!") - cb = DiscreteCallback(condition, affect!) - tstops = [5.0] - test_discrete_callback(cb, tstops, g, dg!) - end - @testset "parameter changing callback at single time point" begin - condition(u, t, integrator) = t == 5.1 - affect!(integrator) = (integrator.p .= 2 * integrator.p .- 0.5) - affect(integrator) = (integrator.p = 2 * integrator.p .- 0.5) - cb = DiscreteCallback(condition, affect!) - cboop = DiscreteCallback(condition, affect) - tstops = [5.1] - test_discrete_callback(cb, tstops, g, dg!, cboop) - end - @testset "tprev dependent callback" begin - condition(u, t, integrator) = t == 5 - affect!(integrator) = (@show integrator.tprev; integrator.u[1] += integrator.t - integrator.tprev) - cb = DiscreteCallback(condition, affect!) - tstops = [4.999, 5.0] - test_discrete_callback(cb, tstops, g, dg!, nothing, true) - end + g(u) = sum((1.0 .- u) .^ 2) ./ 2 + dg!(out, u, p, t, i) = (out .= -1.0 .+ u) + @testset "callbacks with no effect" begin + condition(u, t, integrator) = t == 5 + affect!(integrator) = integrator.u[1] += 0.0 + cb = DiscreteCallback(condition, affect!, save_positions = (false, false)) + tstops = [5.0] + test_discrete_callback(cb, tstops, g, dg!) + end + @testset "callbacks with no effect except saving the state" begin + condition(u, t, integrator) = t == 5 + affect!(integrator) = integrator.u[1] += 0.0 + cb = DiscreteCallback(condition, affect!) + tstops = [5.0] + test_discrete_callback(cb, tstops, g, dg!) + end + @testset "callback at single time point" begin + condition(u, t, integrator) = t == 5 + affect!(integrator) = integrator.u[1] += 2.0 + cb = DiscreteCallback(condition, affect!) + tstops = [5.0] + test_discrete_callback(cb, tstops, g, dg!) + end + @testset "callback at multiple time points" begin + affecttimes = [2.03, 4.0, 8.0] + condition(u, t, integrator) = t ∈ affecttimes + affect!(integrator) = integrator.u[1] += 2.0 + cb = DiscreteCallback(condition, affect!) + test_discrete_callback(cb, affecttimes, g, dg!) + end + @testset "state-dependent += callback at single time point" begin + condition(u, t, integrator) = t == 5 + function affect!(integrator) + (integrator.u .+= integrator.p[2] / 8 * sin.(integrator.u)) + end + cb = DiscreteCallback(condition, affect!) + tstops = [5.0] + test_discrete_callback(cb, tstops, g, dg!) + end + @testset "other callback at single time point" begin + condition(u, t, integrator) = t == 5 + affect!(integrator) = (integrator.u[1] = 2.0; @show "triggered!") + cb = DiscreteCallback(condition, affect!) + tstops = [5.0] + test_discrete_callback(cb, tstops, g, dg!) + end + @testset "parameter changing callback at single time point" begin + condition(u, t, integrator) = t == 5.1 + affect!(integrator) = (integrator.p .= 2 * integrator.p .- 0.5) + affect(integrator) = (integrator.p = 2 * integrator.p .- 0.5) + cb = DiscreteCallback(condition, affect!) + cboop = DiscreteCallback(condition, affect) + tstops = [5.1] + test_discrete_callback(cb, tstops, g, dg!, cboop) + end + @testset "tprev dependent callback" begin + condition(u, t, integrator) = t == 5 + function affect!(integrator) + (@show integrator.tprev; integrator.u[1] += integrator.t - integrator.tprev) + end + cb = DiscreteCallback(condition, affect!) + tstops = [4.999, 5.0] + test_discrete_callback(cb, tstops, g, dg!, nothing, true) + end end - end -end +end end diff --git a/test/callbacks/forward_sensitivity_callback.jl b/test/callbacks/forward_sensitivity_callback.jl index 2ec84fef8..b0a4dfd88 100644 --- a/test/callbacks/forward_sensitivity_callback.jl +++ b/test/callbacks/forward_sensitivity_callback.jl @@ -8,46 +8,57 @@ reltol = 1e-6 savingtimes = 0.1 function test_discrete_callback(cb, tstops, g) - function fiip(du, u, p, t) - #du[1] = dx = p[1]*u[1] - du[:] .= p[1]*u - end + function fiip(du, u, p, t) + #du[1] = dx = p[1]*u[1] + du[:] .= p[1] * u + end - p = Float64[0.8123198] - u0 = Float64[1.0] + p = Float64[0.8123198] + u0 = Float64[1.0] - prob = ODEProblem(fiip, u0, (0.0, 1.0), p) + prob = ODEProblem(fiip, u0, (0.0, 1.0), p) - @show g(solve(prob, Tsit5(), callback=cb, tstops=tstops, abstol=abstol, reltol=reltol, saveat=savingtimes)) + @show g(solve(prob, Tsit5(), callback = cb, tstops = tstops, abstol = abstol, + reltol = reltol, saveat = savingtimes)) - du01, dp1 = Zygote.gradient( - (u0, p) -> g(solve(prob, Tsit5(), u0=u0, p=p, callback=cb, tstops=tstops, abstol=abstol, reltol=reltol, saveat=savingtimes, - sensealg=ForwardDiffSensitivity(;convert_tspan=true))), - u0, p) + du01, dp1 = Zygote.gradient((u0, p) -> g(solve(prob, Tsit5(), u0 = u0, p = p, + callback = cb, tstops = tstops, + abstol = abstol, reltol = reltol, + saveat = savingtimes, + sensealg = ForwardDiffSensitivity(; + convert_tspan = true))), + u0, p) - dstuff1 = ForwardDiff.gradient( - (θ) -> g(solve(prob, Tsit5(), u0=θ[1:1], p=θ[2:2], callback=cb, tstops=tstops, abstol=abstol, reltol=reltol, saveat=savingtimes)), - [u0; p]) + dstuff1 = ForwardDiff.gradient((θ) -> g(solve(prob, Tsit5(), u0 = θ[1:1], p = θ[2:2], + callback = cb, tstops = tstops, + abstol = abstol, reltol = reltol, + saveat = savingtimes)), + [u0; p]) - dstuff2 = FiniteDiff.finite_difference_gradient( - (θ) -> g(solve(prob, Tsit5(), u0=θ[1:1], p=θ[2:2], callback=cb, tstops=tstops, abstol=abstol, reltol=reltol, saveat=savingtimes)), - [u0; p]) + dstuff2 = FiniteDiff.finite_difference_gradient((θ) -> g(solve(prob, Tsit5(), + u0 = θ[1:1], p = θ[2:2], + callback = cb, + tstops = tstops, + abstol = abstol, + reltol = reltol, + saveat = savingtimes)), + [u0; p]) - @show du01 dp1 dstuff1 dstuff2 - @test du01 ≈ dstuff1[1:1] atol=1e-6 - @test dp1 ≈ dstuff1[2:2] atol=1e-6 - @test du01 ≈ dstuff2[1:1] atol=1e-6 - @test dp1 ≈ dstuff2[2:2] atol=1e-6 + @show du01 dp1 dstuff1 dstuff2 + @test du01≈dstuff1[1:1] atol=1e-6 + @test dp1≈dstuff1[2:2] atol=1e-6 + @test du01≈dstuff2[1:1] atol=1e-6 + @test dp1≈dstuff2[2:2] atol=1e-6 end @testset "ForwardDiffSensitivity: Discrete callbacks" begin - g(u) = sum(u.^2) - @testset "reset to initial condition" begin - affecttimes = range(0.0, 1.0, length=6)[2:end] - u0 = [1.0] - condition(u, t, integrator) = t ∈ affecttimes - affect!(integrator) = (integrator.u .= u0; @show "triggered!") - cb = DiscreteCallback(condition, affect!, save_positions=(false,false)) - test_discrete_callback(cb, affecttimes, g) - end + g(u) = sum(u .^ 2) + @testset "reset to initial condition" begin + affecttimes = range(0.0, 1.0, length = 6)[2:end] + u0 = [1.0] + condition(u, t, integrator) = t ∈ affecttimes + affect!(integrator) = (integrator.u .= u0; @show "triggered!") + cb = DiscreteCallback(condition, affect!, save_positions = (false, false)) + test_discrete_callback(cb, affecttimes, g) + end end diff --git a/test/callbacks/vector_continuous_callbacks.jl b/test/callbacks/vector_continuous_callbacks.jl index 00fda49ab..47c9e4b68 100644 --- a/test/callbacks/vector_continuous_callbacks.jl +++ b/test/callbacks/vector_continuous_callbacks.jl @@ -6,50 +6,53 @@ reltol = 1e-12 savingtimes = 0.5 # see https://diffeq.sciml.ai/stable/features/callback_functions/#VectorContinuousCallback-Example -function test_vector_continuous_callback(cb,g) - - function f(du, u, p, t) - du[1] = u[2] - du[2] = -p[1] - du[3] = u[4] - du[4] = 0.0 - end - - u0 = [50.0, 0.0, 0.0, 2.0] - tspan = (0.0, 10.0) - p = [9.8, 0.9] - prob = ODEProblem(f,u0,tspan,p) - sol = solve(prob, Tsit5(), callback=cb, abstol=abstol, reltol=reltol, saveat=savingtimes) - - du01, dp1 = @time Zygote.gradient( - (u0, p) -> g(solve(prob, Tsit5(), u0=u0, p=p, callback=cb, abstol=abstol, reltol=reltol, saveat=savingtimes, sensealg=BacksolveAdjoint())), - u0, p) - - dstuff = @time ForwardDiff.gradient( - (θ) -> g(solve(prob, Tsit5(), u0=θ[1:4], p=θ[5:6], callback=cb, abstol=abstol, reltol=reltol, saveat=savingtimes)), - [u0; p]) +function test_vector_continuous_callback(cb, g) + function f(du, u, p, t) + du[1] = u[2] + du[2] = -p[1] + du[3] = u[4] + du[4] = 0.0 + end - @test du01 ≈ dstuff[1:4] - @test dp1 ≈ dstuff[5:6] + u0 = [50.0, 0.0, 0.0, 2.0] + tspan = (0.0, 10.0) + p = [9.8, 0.9] + prob = ODEProblem(f, u0, tspan, p) + sol = solve(prob, Tsit5(), callback = cb, abstol = abstol, reltol = reltol, + saveat = savingtimes) + + du01, dp1 = @time Zygote.gradient((u0, p) -> g(solve(prob, Tsit5(), u0 = u0, p = p, + callback = cb, abstol = abstol, + reltol = reltol, + saveat = savingtimes, + sensealg = BacksolveAdjoint())), + u0, p) + + dstuff = @time ForwardDiff.gradient((θ) -> g(solve(prob, Tsit5(), u0 = θ[1:4], + p = θ[5:6], callback = cb, + abstol = abstol, reltol = reltol, + saveat = savingtimes)), + [u0; p]) + + @test du01 ≈ dstuff[1:4] + @test dp1 ≈ dstuff[5:6] end -@testset "VectorContinuous callbacks" begin - @testset "MSE loss function bouncing-ball like" begin - g(u) = sum((1.0.-u).^2)./2 +@testset "VectorContinuous callbacks" begin @testset "MSE loss function bouncing-ball like" begin + g(u) = sum((1.0 .- u) .^ 2) ./ 2 function condition(out, u, t, integrator) # Event when event_f(u,t) == 0 - out[1] = u[1] - out[2] = (u[3]-10.0)u[3] + out[1] = u[1] + out[2] = (u[3] - 10.0)u[3] end @testset "callback with linear affect" begin - function affect!(integrator, idx) - if idx == 1 - integrator.u[2] = -integrator.p[2]*integrator.u[2] - elseif idx == 2 - integrator.u[4] = -integrator.p[2]*integrator.u[4] + function affect!(integrator, idx) + if idx == 1 + integrator.u[2] = -integrator.p[2] * integrator.u[2] + elseif idx == 2 + integrator.u[4] = -integrator.p[2] * integrator.u[4] + end end - end - cb = VectorContinuousCallback(condition, affect!, 2) - test_vector_continuous_callback(cb, g) + cb = VectorContinuousCallback(condition, affect!, 2) + test_vector_continuous_callback(cb, g) end - end -end +end end diff --git a/test/complex_adjoints.jl b/test/complex_adjoints.jl index e23efbdad..841a50c9d 100644 --- a/test/complex_adjoints.jl +++ b/test/complex_adjoints.jl @@ -1,67 +1,72 @@ -using SciMLSensitivity, OrdinaryDiffEq, Zygote, LinearAlgebra, FiniteDiff, Test -A = [1.0*im 2.0; 3.0 4.0] -u0 = [1.0 0.0*im; 0.0 1.0] -tspan = (0.0, 1.0) - -function f(u,p,t) - (A*u)*(p[1]*t + p[2]*t^2 + p[3]*t^3 + p[4]*t^4) -end - -p = [1.5 + im, 1.0, 3.0, 1.0] -prob = ODEProblem{false}(f,u0,tspan,p) - -utarget = [0.0*im 1.0; 1.0 0.0] - -function loss_adjoint(p) - ufinal = last(solve(prob, Tsit5(), p=p, abstol=1e-12, reltol=1e-12, sensealg = InterpolatingAdjoint())) - loss = 1 - abs(tr(ufinal*utarget')/2)^2 - return loss -end - -grad1 = Zygote.gradient(loss_adjoint,Complex{Float64}[1.5, 1.0, 3.0, 1.0])[1] -grad2 = FiniteDiff.finite_difference_gradient(loss_adjoint,Complex{Float64}[1.5, 1.0, 3.0, 1.0]) -@test grad1 ≈ grad2 - -function rhs(u, p, t) - p .* u -end - -function loss_fun(sol) - final_u = sol[:, end] - err = sum(abs.(final_u)) - return err -end - -function inner_loop(prob, p, loss_fun; sensealg = InterpolatingAdjoint()) - sol = solve(prob, Tsit5(), p=p, saveat=0.1; sensealg) - err = loss_fun(sol) - return err -end - -tspan = (0.0, 1.0) -p = [1.0] -u0=[1.0, 2.0] -prob = ODEProblem(rhs, u0, tspan, p) -grads = Zygote.gradient((p)->inner_loop(prob, p, loss_fun), p)[1] - -u0=[1.0 + 2.0*im, 2.0 + 1.0*im] -prob = ODEProblem(rhs, u0, tspan, p) -dp1 = Zygote.gradient((p)->inner_loop(prob, p, loss_fun), p)[1] -dp2 = Zygote.gradient((p)->inner_loop(prob, p, loss_fun; sensealg = QuadratureAdjoint()), p)[1] -dp3 = Zygote.gradient((p)->inner_loop(prob, p, loss_fun; sensealg = BacksolveAdjoint()), p)[1] -@test dp1 ≈ dp2 ≈ dp3 -@test eltype(dp1) <: Float64 - -function fiip(du,u,p,t) - du[1] = dx = p[1]*u[1] - p[2]*u[1]*u[2] - du[2] = dy = -p[3]*u[2] + p[4]*u[1]*u[2] -end -p = [1.5,1.0,3.0,1.0]; u0 = [1.0; 1.0] -prob = ODEProblem(fiip,complex(u0),(0.0,10.0),complex(p)) - -function sum_of_solution(u0, p) - _prob = remake(prob,u0=u0,p=p) - real(sum(solve(_prob,Tsit5(),reltol=1e-6,abstol=1e-6,saveat=0.1))) -end - -dx = Zygote.gradient(sum_of_solution, complex(u0), complex(p)) +using SciMLSensitivity, OrdinaryDiffEq, Zygote, LinearAlgebra, FiniteDiff, Test +A = [1.0*im 2.0; 3.0 4.0] +u0 = [1.0 0.0*im; 0.0 1.0] +tspan = (0.0, 1.0) + +function f(u, p, t) + (A * u) * (p[1] * t + p[2] * t^2 + p[3] * t^3 + p[4] * t^4) +end + +p = [1.5 + im, 1.0, 3.0, 1.0] +prob = ODEProblem{false}(f, u0, tspan, p) + +utarget = [0.0*im 1.0; 1.0 0.0] + +function loss_adjoint(p) + ufinal = last(solve(prob, Tsit5(), p = p, abstol = 1e-12, reltol = 1e-12, + sensealg = InterpolatingAdjoint())) + loss = 1 - abs(tr(ufinal * utarget') / 2)^2 + return loss +end + +grad1 = Zygote.gradient(loss_adjoint, Complex{Float64}[1.5, 1.0, 3.0, 1.0])[1] +grad2 = FiniteDiff.finite_difference_gradient(loss_adjoint, + Complex{Float64}[1.5, 1.0, 3.0, 1.0]) +@test grad1 ≈ grad2 + +function rhs(u, p, t) + p .* u +end + +function loss_fun(sol) + final_u = sol[:, end] + err = sum(abs.(final_u)) + return err +end + +function inner_loop(prob, p, loss_fun; sensealg = InterpolatingAdjoint()) + sol = solve(prob, Tsit5(), p = p, saveat = 0.1; sensealg) + err = loss_fun(sol) + return err +end + +tspan = (0.0, 1.0) +p = [1.0] +u0 = [1.0, 2.0] +prob = ODEProblem(rhs, u0, tspan, p) +grads = Zygote.gradient((p) -> inner_loop(prob, p, loss_fun), p)[1] + +u0 = [1.0 + 2.0 * im, 2.0 + 1.0 * im] +prob = ODEProblem(rhs, u0, tspan, p) +dp1 = Zygote.gradient((p) -> inner_loop(prob, p, loss_fun), p)[1] +dp2 = Zygote.gradient((p) -> inner_loop(prob, p, loss_fun; sensealg = QuadratureAdjoint()), + p)[1] +dp3 = Zygote.gradient((p) -> inner_loop(prob, p, loss_fun; sensealg = BacksolveAdjoint()), + p)[1] +@test dp1 ≈ dp2 ≈ dp3 +@test eltype(dp1) <: Float64 + +function fiip(du, u, p, t) + du[1] = dx = p[1] * u[1] - p[2] * u[1] * u[2] + du[2] = dy = -p[3] * u[2] + p[4] * u[1] * u[2] +end +p = [1.5, 1.0, 3.0, 1.0]; +u0 = [1.0; 1.0]; +prob = ODEProblem(fiip, complex(u0), (0.0, 10.0), complex(p)) + +function sum_of_solution(u0, p) + _prob = remake(prob, u0 = u0, p = p) + real(sum(solve(_prob, Tsit5(), reltol = 1e-6, abstol = 1e-6, saveat = 0.1))) +end + +dx = Zygote.gradient(sum_of_solution, complex(u0), complex(p)) diff --git a/test/complex_no_u.jl b/test/complex_no_u.jl index fedb62614..7c589b9b7 100644 --- a/test/complex_no_u.jl +++ b/test/complex_no_u.jl @@ -1,6 +1,6 @@ using OrdinaryDiffEq, SciMLSensitivity, LinearAlgebra, Optimization, OptimizationFlux, Flux -nn = Chain(Dense(1,16),Dense(16,16,tanh),Dense(16,2)) -initial,re = Flux.destructure(nn) +nn = Chain(Dense(1, 16), Dense(16, 16, tanh), Dense(16, 2)) +initial, re = Flux.destructure(nn) function ode2!(u, p, t) f1, f2 = re(p)([t]) .+ im @@ -8,13 +8,14 @@ function ode2!(u, p, t) end tspan = (0.0, 10.0) -prob = ODEProblem(ode2!, Complex{Float64}[0;0], tspan, initial) +prob = ODEProblem(ode2!, Complex{Float64}[0; 0], tspan, initial) function loss(p) - sol = last(solve(prob, Tsit5(), p=p, sensealg=BacksolveAdjoint(autojacvec=ZygoteVJP(allow_nothing=true)))) - return norm(sol) + sol = last(solve(prob, Tsit5(), p = p, + sensealg = BacksolveAdjoint(autojacvec = ZygoteVJP(allow_nothing = true)))) + return norm(sol) end -optf = Optimization.OptimizationFunction((x,p) -> loss(x), Optimization.AutoZygote()) +optf = Optimization.OptimizationFunction((x, p) -> loss(x), Optimization.AutoZygote()) optprob = Optimization.OptimizationProblem(optf, initial) -res = Optimization.solve(optprob, ADAM(0.1), maxiters = 100) \ No newline at end of file +res = Optimization.solve(optprob, ADAM(0.1), maxiters = 100) diff --git a/test/concrete_solve_derivatives.jl b/test/concrete_solve_derivatives.jl index 278e28dd3..9b5e89bc3 100644 --- a/test/concrete_solve_derivatives.jl +++ b/test/concrete_solve_derivatives.jl @@ -1,273 +1,454 @@ -using SciMLSensitivity, OrdinaryDiffEq, Zygote -using Test, ForwardDiff -import Tracker, ReverseDiff, ChainRulesCore - -function fiip(du,u,p,t) - du[1] = dx = p[1]*u[1] - p[2]*u[1]*u[2] - du[2] = dy = -p[3]*u[2] + p[4]*u[1]*u[2] -end -function foop(u,p,t) - dx = p[1]*u[1] - p[2]*u[1]*u[2] - dy = -p[3]*u[2] + p[4]*u[1]*u[2] - [dx,dy] -end - -p = [1.5,1.0,3.0,1.0]; u0 = [1.0;1.0] -prob = ODEProblem(fiip,u0,(0.0,10.0),p) -proboop = ODEProblem(foop,u0,(0.0,10.0),p) - -sol = solve(prob,Tsit5(),abstol=1e-14,reltol=1e-14) -@test sol isa ODESolution -sumsol = sum(sol) -@test sum(solve(prob,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14)) == sumsol -@test sum(solve(prob,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,sensealg=ForwardDiffSensitivity())) == sumsol -@test sum(solve(prob,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,sensealg=BacksolveAdjoint())) == sumsol - -### -### adjoint -### - -_sol = solve(prob,Tsit5(),abstol=1e-14,reltol=1e-14) -ū0, adj = adjoint_sensitivities(_sol, Tsit5(), t=0.0:0.1:10, dg_discrete=((out, u, p, t, i)->out .= 1), abstol=1e-14, reltol=1e-14) -du01,dp1 = Zygote.gradient((u0,p)->sum(solve(prob,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.1,sensealg=QuadratureAdjoint())),u0,p) -du02,dp2 = Zygote.gradient((u0,p)->sum(solve(prob,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.1,sensealg=InterpolatingAdjoint())),u0,p) -du03,dp3 = Zygote.gradient((u0,p)->sum(solve(prob,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.1,sensealg=BacksolveAdjoint())),u0,p) -du04,dp4 = Zygote.gradient((u0,p)->sum(solve(prob,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.1,sensealg=TrackerAdjoint())),u0,p) -@test_broken Zygote.gradient((u0,p)->sum(solve(prob,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.1,sensealg=ZygoteAdjoint())),u0,p) isa Tuple -du06,dp6 = Zygote.gradient((u0,p)->sum(solve(prob,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.1,sensealg=ReverseDiffAdjoint())),u0,p) -du07,dp7 = Zygote.gradient((u0,p)->sum(concrete_solve(prob,Tsit5(),u0,p,abstol=1e-14,reltol=1e-14,saveat=0.1,sensealg=BacksolveAdjoint())),u0,p) - -csol = concrete_solve(prob,Tsit5(),abstol=1e-14,reltol=1e-14) - -@test ū0 ≈ du01 rtol=1e-12 -@test ū0 == du02 -@test ū0 ≈ du03 rtol=1e-12 -@test ū0 ≈ du04 rtol=1e-12 -#@test ū0 ≈ du05 rtol=1e-12 -@test ū0 ≈ du06 rtol=1e-12 -@test ū0 ≈ du07 rtol=1e-12 -@test adj ≈ dp1' rtol=1e-12 -@test adj == dp2' -@test adj ≈ dp3' rtol=1e-12 -@test adj ≈ dp4' rtol=1e-12 -#@test adj ≈ dp5' rtol=1e-12 -@test adj ≈ dp6' rtol=1e-12 -@test adj ≈ dp7' rtol=1e-12 - -### -### Direct from prob -### - -du01,dp1 = Zygote.gradient(u0,p) do u0,p - sum(solve(remake(prob,u0=u0,p=p),Tsit5(),abstol=1e-14,reltol=1e-14,saveat=0.1,sensealg=QuadratureAdjoint())) -end - -@test ū0 ≈ du01 rtol=1e-12 -@test adj ≈ dp1' rtol=1e-12 - -### -### forward -### - -du06,dp6 = Zygote.gradient((u0,p)->sum(solve(prob,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.1,sensealg=ForwardSensitivity())),u0,p) -du07,dp7 = Zygote.gradient((u0,p)->sum(solve(prob,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.1,sensealg=ForwardDiffSensitivity())),u0,p) -@test_broken du08,dp8 = Zygote.gradient((u0,p)->sum(solve(prob,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.1,save_idxs = 1:1,sensealg=ForwardSensitivity())),u0,p) -du09,dp9 = Zygote.gradient((u0,p)->sum(solve(prob,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.1,save_idxs = 1:1,sensealg=ForwardDiffSensitivity())),u0,p) - -@test du06 isa Nothing -@test ū0 ≈ du07 rtol=1e-12 -@test adj ≈ dp6' rtol=1e-12 -@test adj ≈ dp7' rtol=1e-12 - -ū02,adj2 = Zygote.gradient((u0,p)->sum(Array(solve(prob,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.1,sensealg=InterpolatingAdjoint()))[1,:]),u0,p) -du05,dp5 = Zygote.gradient((u0,p)->sum(solve(prob,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.1,save_idxs=1:1,sensealg=InterpolatingAdjoint())),u0,p) -du06,dp6 = Zygote.gradient((u0,p)->sum(solve(prob,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.0:0.1:10.0,save_idxs=1:1,sensealg=QuadratureAdjoint())),u0,p) -du07,dp7 = Zygote.gradient((u0,p)->sum(solve(prob,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.1,save_idxs=1,sensealg=InterpolatingAdjoint())),u0,p) -du08,dp8 = Zygote.gradient((u0,p)->sum(solve(prob,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.1,save_idxs=1:1,sensealg=InterpolatingAdjoint())),u0,p) -du09,dp9 = Zygote.gradient((u0,p)->sum(solve(prob,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.1,save_idxs=1,sensealg=ReverseDiffAdjoint())),u0,p) -du010,dp10 = Zygote.gradient((u0,p)->sum(solve(prob,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.1,save_idxs=1:1,sensealg=TrackerAdjoint())),u0,p) -@test_broken du011,dp11 = Zygote.gradient((u0,p)->sum(solve(prob,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.1,save_idxs=1:1,sensealg=ForwardSensitivity())),u0,p) -du012,dp12 = Zygote.gradient((u0,p)->sum(solve(prob,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.1,save_idxs=1:1,sensealg=ForwardDiffSensitivity())),u0,p) - -@test ū02 ≈ du05 rtol=1e-12 -@test ū02 ≈ du06 rtol=1e-12 -@test ū02 ≈ du07 rtol=1e-12 -@test ū02 ≈ du08 rtol=1e-12 -@test ū02 ≈ du09 rtol=1e-12 -@test ū02 ≈ du010 rtol=1e-12 -#@test ū02 ≈ du011 rtol=1e-12 -@test ū02 ≈ du012 rtol=1e-12 -@test adj2 ≈ dp5 rtol=1e-12 -@test adj2 ≈ dp6 rtol=1e-12 -@test adj2 ≈ dp7 rtol=1e-12 -@test adj2 ≈ dp8 rtol=1e-12 -@test adj2 ≈ dp9 rtol=1e-12 -@test adj2 ≈ dp10 rtol=1e-12 -#@test adj2 ≈ dp11 rtol=1e-12 -@test adj2 ≈ dp12 rtol=1e-12 - -### -### Only End -### - -ū0,adj = Zygote.gradient((u0,p)->sum(solve(prob,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,save_everystep=false,save_start=false,sensealg=InterpolatingAdjoint())),u0,p) -du03,dp3 = Zygote.gradient((u0,p)->sum(solve(prob,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,save_everystep=false,save_start=false,sensealg=ReverseDiffAdjoint())),u0,p) -du04,dp4 = Zygote.gradient((u0,p)->sum(solve(prob,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,save_everystep=false,save_start=false,sensealg=InterpolatingAdjoint())[end]),u0,p) -@test ū0 ≈ du03 rtol=1e-11 -@test ū0 ≈ du04 rtol=1e-12 -@test adj ≈ dp3 rtol=1e-12 -@test adj ≈ dp4 rtol=1e-12 - -### -### OOPs -### - -_sol = solve(prob,Tsit5(),abstol=1e-14,reltol=1e-14) -ū0,adj = adjoint_sensitivities(_sol,Tsit5(), t=0.0:0.1:10, dg_discrete=((out,u,p,t,i) -> out .= 1),abstol=1e-14,reltol=1e-14) - -### -### adjoint -### - -du01,dp1 = Zygote.gradient((u0,p)->sum(solve(proboop,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.1,sensealg=QuadratureAdjoint())),u0,p) -du02,dp2 = Zygote.gradient((u0,p)->sum(solve(proboop,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.1,sensealg=InterpolatingAdjoint())),u0,p) -du03,dp3 = Zygote.gradient((u0,p)->sum(solve(proboop,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.1,sensealg=BacksolveAdjoint())),u0,p) -du04,dp4 = Zygote.gradient((u0,p)->sum(solve(proboop,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.1,sensealg=TrackerAdjoint())),u0,p) -@test_broken du05,dp5 = Zygote.gradient((u0,p)->sum(solve(proboop,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.1,sensealg=ZygoteAdjoint())),u0,p) isa Tuple -du06,dp6 = Zygote.gradient((u0,p)->sum(solve(proboop,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.1,sensealg=ReverseDiffAdjoint())),u0,p) - -@test ū0 ≈ du01 rtol=1e-12 -@test ū0 ≈ du02 rtol=1e-12 -@test ū0 ≈ du03 rtol=1e-12 -@test ū0 ≈ du04 rtol=1e-12 -#@test ū0 ≈ du05 rtol=1e-12 -@test ū0 ≈ du06 rtol=1e-12 -@test adj ≈ dp1' rtol=1e-12 -@test adj ≈ dp2' rtol=1e-12 -@test adj ≈ dp3' rtol=1e-12 -@test adj ≈ dp4' rtol=1e-12 -#@test adj ≈ dp5' rtol=1e-12 -@test adj ≈ dp6' rtol=1e-12 - -### -### forward -### - -@test_broken Zygote.gradient((u0,p)->sum(solve(proboop,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.1,sensealg=ForwardSensitivity())),u0,p) isa Tuple -du07,dp7 = Zygote.gradient((u0,p)->sum(solve(proboop,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.1,sensealg=ForwardDiffSensitivity())),u0,p) - -#@test du06 === nothing -@test du07 ≈ ū0 rtol=1e-12 -#@test adj ≈ dp6' rtol=1e-12 -@test adj ≈ dp7' rtol=1e-12 - -ū02,adj2 = Zygote.gradient((u0,p)->sum(Array(solve(proboop,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.1,sensealg=InterpolatingAdjoint()))[1,:]),u0,p) -du05,dp5 = Zygote.gradient((u0,p)->sum(solve(proboop,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.1,save_idxs=1:1,sensealg=InterpolatingAdjoint())),u0,p) -du06,dp6 = Zygote.gradient((u0,p)->sum(solve(proboop,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.0:0.1:10.0,save_idxs=1:1,sensealg=QuadratureAdjoint())),u0,p) -du07,dp7 = Zygote.gradient((u0,p)->sum(solve(proboop,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.1,save_idxs=1,sensealg=InterpolatingAdjoint())),u0,p) -du08,dp8 = Zygote.gradient((u0,p)->sum(solve(proboop,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.1,save_idxs=1:1,sensealg=InterpolatingAdjoint())),u0,p) -du09,dp9 = Zygote.gradient((u0,p)->sum(solve(proboop,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.1,save_idxs=1,sensealg=ReverseDiffAdjoint())),u0,p) -du010,dp10 = Zygote.gradient((u0,p)->sum(solve(proboop,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.1,save_idxs=1:1,sensealg=TrackerAdjoint())),u0,p) -@test_broken du011,dp11 = Zygote.gradient((u0,p)->sum(solve(proboop,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.1,save_idxs=1:1,sensealg=ForwardSensitivity())),u0,p) -du012,dp12 = Zygote.gradient((u0,p)->sum(solve(proboop,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.1,save_idxs=1:1,sensealg=ForwardDiffSensitivity())),u0,p) -# Redundent to test aliasing -du013,dp13 = Zygote.gradient((u0,p)->sum(solve(proboop,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.1,save_idxs=1:1,sensealg=InterpolatingAdjoint())),u0,p) -du014,dp14 = Zygote.gradient((u0,p)->sum(solve(proboop,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,save_idxs=1,saveat=0.1,sensealg=InterpolatingAdjoint())),u0,p) - - -@test ū02 ≈ du05 rtol=1e-12 -@test ū02 ≈ du06 rtol=1e-12 -@test ū02 ≈ du07 rtol=1e-12 -@test ū02 ≈ du08 rtol=1e-12 -@test ū02 ≈ du09 rtol=1e-12 -@test ū02 ≈ du010 rtol=1e-12 -#@test ū02 ≈ du011 rtol=1e-12 -@test ū02 ≈ du012 rtol=1e-12 -@test ū02 ≈ du013 rtol=1e-12 -@test ū02 ≈ du014 rtol=1e-12 -@test adj2 ≈ dp5 rtol=1e-12 -@test adj2 ≈ dp6 rtol=1e-12 -@test adj2 ≈ dp7 rtol=1e-12 -@test adj2 ≈ dp8 rtol=1e-12 -@test adj2 ≈ dp9 rtol=1e-12 -@test adj2 ≈ dp10 rtol=1e-12 -#@test adj2 ≈ dp11 rtol=1e-12 -@test adj2 ≈ dp12 rtol=1e-12 -@test adj2 ≈ dp13 rtol=1e-12 -@test adj2 ≈ dp14 rtol=1e-12 - -# Handle VecOfArray Derivatives -dp1 = Zygote.gradient((p)->sum(last(solve(prob,Tsit5(),p=p,saveat=10.0,abstol=1e-14,reltol=1e-14))),p)[1] -dp2 = ForwardDiff.gradient((p)->sum(last(solve(prob,Tsit5(),p=p,saveat=10.0,abstol=1e-14,reltol=1e-14))),p) -@test dp1 ≈ dp2 - -dp1 = Zygote.gradient((p)->sum(last(solve(proboop,Tsit5(),u0=u0,p=p,saveat=10.0,abstol=1e-14,reltol=1e-14))),p)[1] -dp2 = ForwardDiff.gradient((p)->sum(last(solve(proboop,Tsit5(),u0=u0,p=p,saveat=10.0,abstol=1e-14,reltol=1e-14))),p) -@test dp1 ≈ dp2 - - -# tspan[2]-tspan[1] not a multiple of saveat tests -du0,dp = Zygote.gradient((u0,p)->sum(solve(proboop,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=2.3,sensealg=ReverseDiffAdjoint())),u0,p) -du01,dp1 = Zygote.gradient((u0,p)->sum(solve(proboop,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=2.3,sensealg=QuadratureAdjoint())),u0,p) -du02,dp2 = Zygote.gradient((u0,p)->sum(solve(proboop,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=2.3,sensealg=InterpolatingAdjoint())),u0,p) -du03,dp3 = Zygote.gradient((u0,p)->sum(solve(proboop,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=2.3,sensealg=BacksolveAdjoint())),u0,p) -du04,dp4 = Zygote.gradient((u0,p)->sum(solve(proboop,Tsit5(),save_end=true,u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=2.3,sensealg=ForwardDiffSensitivity())),u0,p) - -@test du0 ≈ du01 rtol=1e-12 -@test du0 ≈ du02 rtol=1e-12 -@test du0 ≈ du03 rtol=1e-12 -@test du0 ≈ du04 rtol=1e-12 -@test dp ≈ dp1 rtol=1e-12 -@test dp ≈ dp2 rtol=1e-12 -@test dp ≈ dp3 rtol=1e-12 -@test dp ≈ dp4 rtol=1e-12 - -### -### SDE -### - -using StochasticDiffEq -using Random -seed = 100 - -function σiip(du,u,p,t) - du[1] = p[5]*u[1] - du[2] = p[6]*u[2] -end - -function σoop(u,p,t) - dx = p[5]*u[1] - dy = p[6]*u[2] - [dx,dy] -end - -function σoop(u::Tracker.TrackedArray,p,t) - dx = p[5]*u[1] - dy = p[6]*u[2] - Tracker.collect([dx,dy]) -end - -p = [1.5,1.0,3.0,1.0,0.1,0.1] -u0 = [1.0;1.0] -tarray = collect(0.0:0.01:1) - -prob = SDEProblem(fiip,σiip,u0,(0.0,1.0),p) -proboop = SDEProblem(foop,σoop,u0,(0.0,1.0),p) - - -### -### OOPs -### - -_sol = solve(proboop,EulerHeun(),dt=1e-2,adaptive=false,save_noise=true,seed=seed) -ū0,adj = adjoint_sensitivities(_sol,EulerHeun(), t=tarray, dg_discrete=((out,u,p,t,i) -> out .= 1), sensealg=BacksolveAdjoint()) - -du01,dp1 = Zygote.gradient((u0,p)->sum(solve(proboop,EulerHeun(), - u0=u0,p=p,dt=1e-2,saveat=0.01,sensealg=BacksolveAdjoint(),seed=seed)),u0,p) - -du02,dp2 = Zygote.gradient( - (u0,p)->sum(solve(proboop,EulerHeun(),u0=u0,p=p,dt=1e-2,saveat=0.01,sensealg=ForwardDiffSensitivity(),seed=seed)),u0,p) - -@test isapprox(ū0, du01, rtol = 1e-4) -@test isapprox(adj, dp1', rtol = 1e-4) -@test isapprox(adj, dp2', rtol = 1e-4) +using SciMLSensitivity, OrdinaryDiffEq, Zygote +using Test, ForwardDiff +import Tracker, ReverseDiff, ChainRulesCore + +function fiip(du, u, p, t) + du[1] = dx = p[1] * u[1] - p[2] * u[1] * u[2] + du[2] = dy = -p[3] * u[2] + p[4] * u[1] * u[2] +end +function foop(u, p, t) + dx = p[1] * u[1] - p[2] * u[1] * u[2] + dy = -p[3] * u[2] + p[4] * u[1] * u[2] + [dx, dy] +end + +p = [1.5, 1.0, 3.0, 1.0]; +u0 = [1.0; 1.0]; +prob = ODEProblem(fiip, u0, (0.0, 10.0), p) +proboop = ODEProblem(foop, u0, (0.0, 10.0), p) + +sol = solve(prob, Tsit5(), abstol = 1e-14, reltol = 1e-14) +@test sol isa ODESolution +sumsol = sum(sol) +@test sum(solve(prob, Tsit5(), u0 = u0, p = p, abstol = 1e-14, reltol = 1e-14)) == sumsol +@test sum(solve(prob, Tsit5(), u0 = u0, p = p, abstol = 1e-14, reltol = 1e-14, + sensealg = ForwardDiffSensitivity())) == sumsol +@test sum(solve(prob, Tsit5(), u0 = u0, p = p, abstol = 1e-14, reltol = 1e-14, + sensealg = BacksolveAdjoint())) == sumsol + +### +### adjoint +### + +_sol = solve(prob, Tsit5(), abstol = 1e-14, reltol = 1e-14) +ū0, adj = adjoint_sensitivities(_sol, Tsit5(), t = 0.0:0.1:10, + dg_discrete = ((out, u, p, t, i) -> out .= 1), + abstol = 1e-14, reltol = 1e-14) +du01, dp1 = Zygote.gradient((u0, p) -> sum(solve(prob, Tsit5(), u0 = u0, p = p, + abstol = 1e-14, reltol = 1e-14, + saveat = 0.1, + sensealg = QuadratureAdjoint())), u0, p) +du02, dp2 = Zygote.gradient((u0, p) -> sum(solve(prob, Tsit5(), u0 = u0, p = p, + abstol = 1e-14, reltol = 1e-14, + saveat = 0.1, + sensealg = InterpolatingAdjoint())), u0, p) +du03, dp3 = Zygote.gradient((u0, p) -> sum(solve(prob, Tsit5(), u0 = u0, p = p, + abstol = 1e-14, reltol = 1e-14, + saveat = 0.1, + sensealg = BacksolveAdjoint())), u0, p) +du04, dp4 = Zygote.gradient((u0, p) -> sum(solve(prob, Tsit5(), u0 = u0, p = p, + abstol = 1e-14, reltol = 1e-14, + saveat = 0.1, sensealg = TrackerAdjoint())), + u0, p) +@test_broken Zygote.gradient((u0, p) -> sum(solve(prob, Tsit5(), u0 = u0, p = p, + abstol = 1e-14, reltol = 1e-14, + saveat = 0.1, sensealg = ZygoteAdjoint())), + u0, p) isa Tuple +du06, dp6 = Zygote.gradient((u0, p) -> sum(solve(prob, Tsit5(), u0 = u0, p = p, + abstol = 1e-14, reltol = 1e-14, + saveat = 0.1, + sensealg = ReverseDiffAdjoint())), u0, p) +du07, dp7 = Zygote.gradient((u0, p) -> sum(concrete_solve(prob, Tsit5(), u0, p, + abstol = 1e-14, reltol = 1e-14, + saveat = 0.1, + sensealg = BacksolveAdjoint())), + u0, p) + +csol = concrete_solve(prob, Tsit5(), abstol = 1e-14, reltol = 1e-14) + +@test ū0≈du01 rtol=1e-12 +@test ū0 == du02 +@test ū0≈du03 rtol=1e-12 +@test ū0≈du04 rtol=1e-12 +#@test ū0 ≈ du05 rtol=1e-12 +@test ū0≈du06 rtol=1e-12 +@test ū0≈du07 rtol=1e-12 +@test adj≈dp1' rtol=1e-12 +@test adj == dp2' +@test adj≈dp3' rtol=1e-12 +@test adj≈dp4' rtol=1e-12 +#@test adj ≈ dp5' rtol=1e-12 +@test adj≈dp6' rtol=1e-12 +@test adj≈dp7' rtol=1e-12 + +### +### Direct from prob +### + +du01, dp1 = Zygote.gradient(u0, p) do u0, p + sum(solve(remake(prob, u0 = u0, p = p), Tsit5(), abstol = 1e-14, reltol = 1e-14, + saveat = 0.1, sensealg = QuadratureAdjoint())) +end + +@test ū0≈du01 rtol=1e-12 +@test adj≈dp1' rtol=1e-12 + +### +### forward +### + +du06, dp6 = Zygote.gradient((u0, p) -> sum(solve(prob, Tsit5(), u0 = u0, p = p, + abstol = 1e-14, reltol = 1e-14, + saveat = 0.1, + sensealg = ForwardSensitivity())), u0, p) +du07, dp7 = Zygote.gradient((u0, p) -> sum(solve(prob, Tsit5(), u0 = u0, p = p, + abstol = 1e-14, reltol = 1e-14, + saveat = 0.1, + sensealg = ForwardDiffSensitivity())), u0, + p) +@test_broken du08, dp8 = Zygote.gradient((u0, p) -> sum(solve(prob, Tsit5(), u0 = u0, p = p, + abstol = 1e-14, + reltol = 1e-14, saveat = 0.1, + save_idxs = 1:1, + sensealg = ForwardSensitivity())), + u0, p) +du09, dp9 = Zygote.gradient((u0, p) -> sum(solve(prob, Tsit5(), u0 = u0, p = p, + abstol = 1e-14, reltol = 1e-14, + saveat = 0.1, save_idxs = 1:1, + sensealg = ForwardDiffSensitivity())), u0, + p) + +@test du06 isa Nothing +@test ū0≈du07 rtol=1e-12 +@test adj≈dp6' rtol=1e-12 +@test adj≈dp7' rtol=1e-12 + +ū02, adj2 = Zygote.gradient((u0, p) -> sum(Array(solve(prob, Tsit5(), u0 = u0, p = p, + abstol = 1e-14, reltol = 1e-14, + saveat = 0.1, + sensealg = InterpolatingAdjoint()))[1, + :]), + u0, p) +du05, dp5 = Zygote.gradient((u0, p) -> sum(solve(prob, Tsit5(), u0 = u0, p = p, + abstol = 1e-14, reltol = 1e-14, + saveat = 0.1, save_idxs = 1:1, + sensealg = InterpolatingAdjoint())), u0, p) +du06, dp6 = Zygote.gradient((u0, p) -> sum(solve(prob, Tsit5(), u0 = u0, p = p, + abstol = 1e-14, reltol = 1e-14, + saveat = 0.0:0.1:10.0, save_idxs = 1:1, + sensealg = QuadratureAdjoint())), u0, p) +du07, dp7 = Zygote.gradient((u0, p) -> sum(solve(prob, Tsit5(), u0 = u0, p = p, + abstol = 1e-14, reltol = 1e-14, + saveat = 0.1, save_idxs = 1, + sensealg = InterpolatingAdjoint())), u0, p) +du08, dp8 = Zygote.gradient((u0, p) -> sum(solve(prob, Tsit5(), u0 = u0, p = p, + abstol = 1e-14, reltol = 1e-14, + saveat = 0.1, save_idxs = 1:1, + sensealg = InterpolatingAdjoint())), u0, p) +du09, dp9 = Zygote.gradient((u0, p) -> sum(solve(prob, Tsit5(), u0 = u0, p = p, + abstol = 1e-14, reltol = 1e-14, + saveat = 0.1, save_idxs = 1, + sensealg = ReverseDiffAdjoint())), u0, p) +du010, dp10 = Zygote.gradient((u0, p) -> sum(solve(prob, Tsit5(), u0 = u0, p = p, + abstol = 1e-14, reltol = 1e-14, + saveat = 0.1, save_idxs = 1:1, + sensealg = TrackerAdjoint())), u0, p) +@test_broken du011, dp11 = Zygote.gradient((u0, p) -> sum(solve(prob, Tsit5(), u0 = u0, + p = p, abstol = 1e-14, + reltol = 1e-14, + saveat = 0.1, + save_idxs = 1:1, + sensealg = ForwardSensitivity())), + u0, p) +du012, dp12 = Zygote.gradient((u0, p) -> sum(solve(prob, Tsit5(), u0 = u0, p = p, + abstol = 1e-14, reltol = 1e-14, + saveat = 0.1, save_idxs = 1:1, + sensealg = ForwardDiffSensitivity())), + u0, p) + +@test ū02≈du05 rtol=1e-12 +@test ū02≈du06 rtol=1e-12 +@test ū02≈du07 rtol=1e-12 +@test ū02≈du08 rtol=1e-12 +@test ū02≈du09 rtol=1e-12 +@test ū02≈du010 rtol=1e-12 +#@test ū02 ≈ du011 rtol=1e-12 +@test ū02≈du012 rtol=1e-12 +@test adj2≈dp5 rtol=1e-12 +@test adj2≈dp6 rtol=1e-12 +@test adj2≈dp7 rtol=1e-12 +@test adj2≈dp8 rtol=1e-12 +@test adj2≈dp9 rtol=1e-12 +@test adj2≈dp10 rtol=1e-12 +#@test adj2 ≈ dp11 rtol=1e-12 +@test adj2≈dp12 rtol=1e-12 + +### +### Only End +### + +ū0, adj = Zygote.gradient((u0, p) -> sum(solve(prob, Tsit5(), u0 = u0, p = p, + abstol = 1e-14, reltol = 1e-14, + save_everystep = false, save_start = false, + sensealg = InterpolatingAdjoint())), u0, p) +du03, dp3 = Zygote.gradient((u0, p) -> sum(solve(prob, Tsit5(), u0 = u0, p = p, + abstol = 1e-14, reltol = 1e-14, + save_everystep = false, save_start = false, + sensealg = ReverseDiffAdjoint())), u0, p) +du04, dp4 = Zygote.gradient((u0, p) -> sum(solve(prob, Tsit5(), u0 = u0, p = p, + abstol = 1e-14, reltol = 1e-14, + save_everystep = false, save_start = false, + sensealg = InterpolatingAdjoint())[end]), + u0, p) +@test ū0≈du03 rtol=1e-11 +@test ū0≈du04 rtol=1e-12 +@test adj≈dp3 rtol=1e-12 +@test adj≈dp4 rtol=1e-12 + +### +### OOPs +### + +_sol = solve(prob, Tsit5(), abstol = 1e-14, reltol = 1e-14) +ū0, adj = adjoint_sensitivities(_sol, Tsit5(), t = 0.0:0.1:10, + dg_discrete = ((out, u, p, t, i) -> out .= 1), + abstol = 1e-14, reltol = 1e-14) + +### +### adjoint +### + +du01, dp1 = Zygote.gradient((u0, p) -> sum(solve(proboop, Tsit5(), u0 = u0, p = p, + abstol = 1e-14, reltol = 1e-14, + saveat = 0.1, + sensealg = QuadratureAdjoint())), u0, p) +du02, dp2 = Zygote.gradient((u0, p) -> sum(solve(proboop, Tsit5(), u0 = u0, p = p, + abstol = 1e-14, reltol = 1e-14, + saveat = 0.1, + sensealg = InterpolatingAdjoint())), u0, p) +du03, dp3 = Zygote.gradient((u0, p) -> sum(solve(proboop, Tsit5(), u0 = u0, p = p, + abstol = 1e-14, reltol = 1e-14, + saveat = 0.1, + sensealg = BacksolveAdjoint())), u0, p) +du04, dp4 = Zygote.gradient((u0, p) -> sum(solve(proboop, Tsit5(), u0 = u0, p = p, + abstol = 1e-14, reltol = 1e-14, + saveat = 0.1, sensealg = TrackerAdjoint())), + u0, p) +@test_broken du05, dp5 = Zygote.gradient((u0, p) -> sum(solve(proboop, Tsit5(), u0 = u0, + p = p, abstol = 1e-14, + reltol = 1e-14, saveat = 0.1, + sensealg = ZygoteAdjoint())), + u0, p) isa Tuple +du06, dp6 = Zygote.gradient((u0, p) -> sum(solve(proboop, Tsit5(), u0 = u0, p = p, + abstol = 1e-14, reltol = 1e-14, + saveat = 0.1, + sensealg = ReverseDiffAdjoint())), u0, p) + +@test ū0≈du01 rtol=1e-12 +@test ū0≈du02 rtol=1e-12 +@test ū0≈du03 rtol=1e-12 +@test ū0≈du04 rtol=1e-12 +#@test ū0 ≈ du05 rtol=1e-12 +@test ū0≈du06 rtol=1e-12 +@test adj≈dp1' rtol=1e-12 +@test adj≈dp2' rtol=1e-12 +@test adj≈dp3' rtol=1e-12 +@test adj≈dp4' rtol=1e-12 +#@test adj ≈ dp5' rtol=1e-12 +@test adj≈dp6' rtol=1e-12 + +### +### forward +### + +@test_broken Zygote.gradient((u0, p) -> sum(solve(proboop, Tsit5(), u0 = u0, p = p, + abstol = 1e-14, reltol = 1e-14, + saveat = 0.1, + sensealg = ForwardSensitivity())), u0, + p) isa Tuple +du07, dp7 = Zygote.gradient((u0, p) -> sum(solve(proboop, Tsit5(), u0 = u0, p = p, + abstol = 1e-14, reltol = 1e-14, + saveat = 0.1, + sensealg = ForwardDiffSensitivity())), u0, + p) + +#@test du06 === nothing +@test du07≈ū0 rtol=1e-12 +#@test adj ≈ dp6' rtol=1e-12 +@test adj≈dp7' rtol=1e-12 + +ū02, adj2 = Zygote.gradient((u0, p) -> sum(Array(solve(proboop, Tsit5(), u0 = u0, p = p, + abstol = 1e-14, reltol = 1e-14, + saveat = 0.1, + sensealg = InterpolatingAdjoint()))[1, + :]), + u0, p) +du05, dp5 = Zygote.gradient((u0, p) -> sum(solve(proboop, Tsit5(), u0 = u0, p = p, + abstol = 1e-14, reltol = 1e-14, + saveat = 0.1, save_idxs = 1:1, + sensealg = InterpolatingAdjoint())), u0, p) +du06, dp6 = Zygote.gradient((u0, p) -> sum(solve(proboop, Tsit5(), u0 = u0, p = p, + abstol = 1e-14, reltol = 1e-14, + saveat = 0.0:0.1:10.0, save_idxs = 1:1, + sensealg = QuadratureAdjoint())), u0, p) +du07, dp7 = Zygote.gradient((u0, p) -> sum(solve(proboop, Tsit5(), u0 = u0, p = p, + abstol = 1e-14, reltol = 1e-14, + saveat = 0.1, save_idxs = 1, + sensealg = InterpolatingAdjoint())), u0, p) +du08, dp8 = Zygote.gradient((u0, p) -> sum(solve(proboop, Tsit5(), u0 = u0, p = p, + abstol = 1e-14, reltol = 1e-14, + saveat = 0.1, save_idxs = 1:1, + sensealg = InterpolatingAdjoint())), u0, p) +du09, dp9 = Zygote.gradient((u0, p) -> sum(solve(proboop, Tsit5(), u0 = u0, p = p, + abstol = 1e-14, reltol = 1e-14, + saveat = 0.1, save_idxs = 1, + sensealg = ReverseDiffAdjoint())), u0, p) +du010, dp10 = Zygote.gradient((u0, p) -> sum(solve(proboop, Tsit5(), u0 = u0, p = p, + abstol = 1e-14, reltol = 1e-14, + saveat = 0.1, save_idxs = 1:1, + sensealg = TrackerAdjoint())), u0, p) +@test_broken du011, dp11 = Zygote.gradient((u0, p) -> sum(solve(proboop, Tsit5(), u0 = u0, + p = p, abstol = 1e-14, + reltol = 1e-14, + saveat = 0.1, + save_idxs = 1:1, + sensealg = ForwardSensitivity())), + u0, p) +du012, dp12 = Zygote.gradient((u0, p) -> sum(solve(proboop, Tsit5(), u0 = u0, p = p, + abstol = 1e-14, reltol = 1e-14, + saveat = 0.1, save_idxs = 1:1, + sensealg = ForwardDiffSensitivity())), + u0, p) +# Redundent to test aliasing +du013, dp13 = Zygote.gradient((u0, p) -> sum(solve(proboop, Tsit5(), u0 = u0, p = p, + abstol = 1e-14, reltol = 1e-14, + saveat = 0.1, save_idxs = 1:1, + sensealg = InterpolatingAdjoint())), u0, + p) +du014, dp14 = Zygote.gradient((u0, p) -> sum(solve(proboop, Tsit5(), u0 = u0, p = p, + abstol = 1e-14, reltol = 1e-14, + save_idxs = 1, saveat = 0.1, + sensealg = InterpolatingAdjoint())), u0, + p) + +@test ū02≈du05 rtol=1e-12 +@test ū02≈du06 rtol=1e-12 +@test ū02≈du07 rtol=1e-12 +@test ū02≈du08 rtol=1e-12 +@test ū02≈du09 rtol=1e-12 +@test ū02≈du010 rtol=1e-12 +#@test ū02 ≈ du011 rtol=1e-12 +@test ū02≈du012 rtol=1e-12 +@test ū02≈du013 rtol=1e-12 +@test ū02≈du014 rtol=1e-12 +@test adj2≈dp5 rtol=1e-12 +@test adj2≈dp6 rtol=1e-12 +@test adj2≈dp7 rtol=1e-12 +@test adj2≈dp8 rtol=1e-12 +@test adj2≈dp9 rtol=1e-12 +@test adj2≈dp10 rtol=1e-12 +#@test adj2 ≈ dp11 rtol=1e-12 +@test adj2≈dp12 rtol=1e-12 +@test adj2≈dp13 rtol=1e-12 +@test adj2≈dp14 rtol=1e-12 + +# Handle VecOfArray Derivatives +dp1 = Zygote.gradient((p) -> sum(last(solve(prob, Tsit5(), p = p, saveat = 10.0, + abstol = 1e-14, reltol = 1e-14))), p)[1] +dp2 = ForwardDiff.gradient((p) -> sum(last(solve(prob, Tsit5(), p = p, saveat = 10.0, + abstol = 1e-14, reltol = 1e-14))), p) +@test dp1 ≈ dp2 + +dp1 = Zygote.gradient((p) -> sum(last(solve(proboop, Tsit5(), u0 = u0, p = p, saveat = 10.0, + abstol = 1e-14, reltol = 1e-14))), p)[1] +dp2 = ForwardDiff.gradient((p) -> sum(last(solve(proboop, Tsit5(), u0 = u0, p = p, + saveat = 10.0, abstol = 1e-14, + reltol = 1e-14))), p) +@test dp1 ≈ dp2 + +# tspan[2]-tspan[1] not a multiple of saveat tests +du0, dp = Zygote.gradient((u0, p) -> sum(solve(proboop, Tsit5(), u0 = u0, p = p, + abstol = 1e-14, reltol = 1e-14, saveat = 2.3, + sensealg = ReverseDiffAdjoint())), u0, p) +du01, dp1 = Zygote.gradient((u0, p) -> sum(solve(proboop, Tsit5(), u0 = u0, p = p, + abstol = 1e-14, reltol = 1e-14, + saveat = 2.3, + sensealg = QuadratureAdjoint())), u0, p) +du02, dp2 = Zygote.gradient((u0, p) -> sum(solve(proboop, Tsit5(), u0 = u0, p = p, + abstol = 1e-14, reltol = 1e-14, + saveat = 2.3, + sensealg = InterpolatingAdjoint())), u0, p) +du03, dp3 = Zygote.gradient((u0, p) -> sum(solve(proboop, Tsit5(), u0 = u0, p = p, + abstol = 1e-14, reltol = 1e-14, + saveat = 2.3, + sensealg = BacksolveAdjoint())), u0, p) +du04, dp4 = Zygote.gradient((u0, p) -> sum(solve(proboop, Tsit5(), save_end = true, u0 = u0, + p = p, abstol = 1e-14, reltol = 1e-14, + saveat = 2.3, + sensealg = ForwardDiffSensitivity())), u0, + p) + +@test du0≈du01 rtol=1e-12 +@test du0≈du02 rtol=1e-12 +@test du0≈du03 rtol=1e-12 +@test du0≈du04 rtol=1e-12 +@test dp≈dp1 rtol=1e-12 +@test dp≈dp2 rtol=1e-12 +@test dp≈dp3 rtol=1e-12 +@test dp≈dp4 rtol=1e-12 + +### +### SDE +### + +using StochasticDiffEq +using Random +seed = 100 + +function σiip(du, u, p, t) + du[1] = p[5] * u[1] + du[2] = p[6] * u[2] +end + +function σoop(u, p, t) + dx = p[5] * u[1] + dy = p[6] * u[2] + [dx, dy] +end + +function σoop(u::Tracker.TrackedArray, p, t) + dx = p[5] * u[1] + dy = p[6] * u[2] + Tracker.collect([dx, dy]) +end + +p = [1.5, 1.0, 3.0, 1.0, 0.1, 0.1] +u0 = [1.0; 1.0] +tarray = collect(0.0:0.01:1) + +prob = SDEProblem(fiip, σiip, u0, (0.0, 1.0), p) +proboop = SDEProblem(foop, σoop, u0, (0.0, 1.0), p) + +### +### OOPs +### + +_sol = solve(proboop, EulerHeun(), dt = 1e-2, adaptive = false, save_noise = true, + seed = seed) +ū0, adj = adjoint_sensitivities(_sol, EulerHeun(), t = tarray, + dg_discrete = ((out, u, p, t, i) -> out .= 1), + sensealg = BacksolveAdjoint()) + +du01, dp1 = Zygote.gradient((u0, p) -> sum(solve(proboop, EulerHeun(), + u0 = u0, p = p, dt = 1e-2, saveat = 0.01, + sensealg = BacksolveAdjoint(), + seed = seed)), u0, p) + +du02, dp2 = Zygote.gradient((u0, p) -> sum(solve(proboop, EulerHeun(), u0 = u0, p = p, + dt = 1e-2, saveat = 0.01, + sensealg = ForwardDiffSensitivity(), + seed = seed)), u0, p) + +@test isapprox(ū0, du01, rtol = 1e-4) +@test isapprox(adj, dp1', rtol = 1e-4) +@test isapprox(adj, dp2', rtol = 1e-4) diff --git a/test/derivative_shapes.jl b/test/derivative_shapes.jl index 0bb0a86fa..2dda791c7 100644 --- a/test/derivative_shapes.jl +++ b/test/derivative_shapes.jl @@ -1,36 +1,38 @@ -using OrdinaryDiffEq, SciMLSensitivity, ForwardDiff, Zygote, Test -A = [0. 1.; 1. 0.; 0 0; 0 0]; -B = [1. 0.; 0. 1.; 0 0; 0 0]; - -utarget = A; -const T = 10.0; - -function f(u, p, t) - return -p[1]*u # just a silly example to demonstrate the issue -end - - -u0 = [1.0 0.0; 0.0 1.0; 0.0 0.0; 0.0 0.0]; - -tspan = (0.0, T) -tsteps = 0.0:T/100.0:T - -p = [1.7, 1.0, 3.0, 1.0] - -prob_ode = ODEProblem(f, u0, tspan, p); - -fd_ode = ForwardDiff.gradient(p) do p - sum(last(solve(prob_ode, Tsit5(),p=p,abstol=1e-12,reltol=1e-12))) -end - -grad_ode = Zygote.gradient(p) do p - sum(last(solve(prob_ode, Tsit5(),p=p,abstol=1e-12,reltol=1e-12))) -end[1] - -@test fd_ode ≈ grad_ode rtol=1e-6 - -grad_ode = Zygote.gradient(p) do p - sum(last(solve(prob_ode, Tsit5(),p=p,abstol=1e-12,reltol=1e-12, sensealg = InterpolatingAdjoint()))) -end[1] - -@test fd_ode ≈ grad_ode rtol=1e-6 +using OrdinaryDiffEq, SciMLSensitivity, ForwardDiff, Zygote, Test +A = [0.0 1.0; 1.0 0.0; 0 0; 0 0] +; +B = [1.0 0.0; 0.0 1.0; 0 0; 0 0] + +; +utarget = A; +const T = 10.0; + +function f(u, p, t) + return -p[1] * u # just a silly example to demonstrate the issue +end + +u0 = [1.0 0.0; 0.0 1.0; 0.0 0.0; 0.0 0.0]; + +tspan = (0.0, T) +tsteps = 0.0:(T / 100.0):T + +p = [1.7, 1.0, 3.0, 1.0] + +prob_ode = ODEProblem(f, u0, tspan, p); + +fd_ode = ForwardDiff.gradient(p) do p + sum(last(solve(prob_ode, Tsit5(), p = p, abstol = 1e-12, reltol = 1e-12))) +end + +grad_ode = Zygote.gradient(p) do p + sum(last(solve(prob_ode, Tsit5(), p = p, abstol = 1e-12, reltol = 1e-12))) +end[1] + +@test fd_ode≈grad_ode rtol=1e-6 + +grad_ode = Zygote.gradient(p) do p + sum(last(solve(prob_ode, Tsit5(), p = p, abstol = 1e-12, reltol = 1e-12, + sensealg = InterpolatingAdjoint()))) +end[1] + +@test fd_ode≈grad_ode rtol=1e-6 diff --git a/test/discrete.jl b/test/discrete.jl index fa99ee228..5445a787f 100644 --- a/test/discrete.jl +++ b/test/discrete.jl @@ -1,15 +1,15 @@ -using OrdinaryDiffEq, SciMLSensitivity, Zygote, Test - -function loss1(p;sensealg=nothing) - f(x,p,t) = [p[1]] - prob = DiscreteProblem(f, [0.0], (1,10), p) - sol = solve(prob, FunctionMap(scale_by_time = true), saveat=[1,2,3]) - return sum(sol) -end -dp1 = Zygote.gradient(loss1,[1.0])[1] -dp2 = Zygote.gradient(x->loss1(x,sensealg=TrackerAdjoint()),[1.0])[1] -dp3 = Zygote.gradient(x->loss1(x,sensealg=ReverseDiffAdjoint()),[1.0])[1] -dp4 = Zygote.gradient(x->loss1(x,sensealg=ForwardDiffSensitivity()),[1.0])[1] -@test dp1 == dp2 -@test dp1 == dp3 -@test dp1 == dp4 +using OrdinaryDiffEq, SciMLSensitivity, Zygote, Test + +function loss1(p; sensealg = nothing) + f(x, p, t) = [p[1]] + prob = DiscreteProblem(f, [0.0], (1, 10), p) + sol = solve(prob, FunctionMap(scale_by_time = true), saveat = [1, 2, 3]) + return sum(sol) +end +dp1 = Zygote.gradient(loss1, [1.0])[1] +dp2 = Zygote.gradient(x -> loss1(x, sensealg = TrackerAdjoint()), [1.0])[1] +dp3 = Zygote.gradient(x -> loss1(x, sensealg = ReverseDiffAdjoint()), [1.0])[1] +dp4 = Zygote.gradient(x -> loss1(x, sensealg = ForwardDiffSensitivity()), [1.0])[1] +@test dp1 == dp2 +@test dp1 == dp3 +@test dp1 == dp4 diff --git a/test/distributed.jl b/test/distributed.jl index c46d9dac5..0c9dc4b02 100644 --- a/test/distributed.jl +++ b/test/distributed.jl @@ -2,30 +2,31 @@ using Distributed, Flux addprocs(2) @everywhere begin - using SciMLSensitivity, OrdinaryDiffEq, Test + using SciMLSensitivity, OrdinaryDiffEq, Test - pa = [1.0] - u0 = [3.0] + pa = [1.0] + u0 = [3.0] end function model4() - prob = ODEProblem((u, p, t) -> 1.01u .* p, u0, (0.0, 1.0), pa) + prob = ODEProblem((u, p, t) -> 1.01u .* p, u0, (0.0, 1.0), pa) - function prob_func(prob, i, repeat) - remake(prob, u0 = 0.5 .+ i/100 .* prob.u0) - end + function prob_func(prob, i, repeat) + remake(prob, u0 = 0.5 .+ i / 100 .* prob.u0) + end - ensemble_prob = EnsembleProblem(prob, prob_func = prob_func) - sim = solve(ensemble_prob, Tsit5(), EnsembleDistributed(), saveat = 0.1, trajectories = 100) + ensemble_prob = EnsembleProblem(prob, prob_func = prob_func) + sim = solve(ensemble_prob, Tsit5(), EnsembleDistributed(), saveat = 0.1, + trajectories = 100) end # loss function -loss() = sum(abs2,1.0.-Array(model4())) +loss() = sum(abs2, 1.0 .- Array(model4())) data = Iterators.repeated((), 10) cb = function () # callback function to observe training - @show loss() + @show loss() end pa = [1.0] @@ -33,6 +34,6 @@ u0 = [3.0] opt = Flux.ADAM(0.1) println("Starting to train") l1 = loss() -Flux.@epochs 10 Flux.train!(loss, Flux.params([pa,u0]), data, opt; cb = cb) +Flux.@epochs 10 Flux.train!(loss, Flux.params([pa, u0]), data, opt; cb = cb) l2 = loss() @test 10l2 < l1 diff --git a/test/ensembles.jl b/test/ensembles.jl index 89b7655ce..0daa5873d 100644 --- a/test/ensembles.jl +++ b/test/ensembles.jl @@ -3,71 +3,72 @@ using Flux, OrdinaryDiffEq, Test pa = [1.0] u0 = [3.0] function model1() - prob = ODEProblem((u, p, t) -> 1.01u .* p, u0, (0.0, 1.0), pa) + prob = ODEProblem((u, p, t) -> 1.01u .* p, u0, (0.0, 1.0), pa) - function prob_func(prob, i, repeat) - remake(prob, u0 = 0.5 .+ i/100 .* prob.u0) - end + function prob_func(prob, i, repeat) + remake(prob, u0 = 0.5 .+ i / 100 .* prob.u0) + end - ensemble_prob = EnsembleProblem(prob, prob_func = prob_func) - sim = solve(ensemble_prob, Tsit5(), EnsembleSerial(), saveat = 0.1, trajectories = 100) + ensemble_prob = EnsembleProblem(prob, prob_func = prob_func) + sim = solve(ensemble_prob, Tsit5(), EnsembleSerial(), saveat = 0.1, trajectories = 100) end # loss function -loss() = sum(abs2,1.0.-Array(model1())) +loss() = sum(abs2, 1.0 .- Array(model1())) data = Iterators.repeated((), 10) cb = function () # callback function to observe training - @show loss() + @show loss() end opt = ADAM(0.1) println("Starting to train") l1 = loss() -Flux.@epochs 10 Flux.train!(loss, Flux.params([pa,u0]), data, opt; cb = cb) +Flux.@epochs 10 Flux.train!(loss, Flux.params([pa, u0]), data, opt; cb = cb) l2 = loss() @test 10l2 < l1 function model2() - prob = ODEProblem((u, p, t) -> 1.01u .* p, u0, (0.0, 1.0), pa) + prob = ODEProblem((u, p, t) -> 1.01u .* p, u0, (0.0, 1.0), pa) - function prob_func(prob, i, repeat) - remake(prob, u0 = 0.5 .+ i/100 .* prob.u0) - end + function prob_func(prob, i, repeat) + remake(prob, u0 = 0.5 .+ i / 100 .* prob.u0) + end - ensemble_prob = EnsembleProblem(prob, prob_func = prob_func) - sim = solve(ensemble_prob, Tsit5(), EnsembleSerial(), saveat = 0.1, trajectories = 100).u + ensemble_prob = EnsembleProblem(prob, prob_func = prob_func) + sim = solve(ensemble_prob, Tsit5(), EnsembleSerial(), saveat = 0.1, + trajectories = 100).u end -loss() = sum(abs2,[sum(abs2,1.0.-u) for u in model2()]) +loss() = sum(abs2, [sum(abs2, 1.0 .- u) for u in model2()]) pa = [1.0] u0 = [3.0] opt = ADAM(0.1) println("Starting to train") l1 = loss() -Flux.@epochs 10 Flux.train!(loss, Flux.params([pa,u0]), data, opt; cb = cb) +Flux.@epochs 10 Flux.train!(loss, Flux.params([pa, u0]), data, opt; cb = cb) l2 = loss() @test 10l2 < l1 function model3() - prob = ODEProblem((u, p, t) -> 1.01u .* p, u0, (0.0, 1.0), pa) + prob = ODEProblem((u, p, t) -> 1.01u .* p, u0, (0.0, 1.0), pa) - function prob_func(prob, i, repeat) - remake(prob, u0 = 0.5 .+ i/100 .* prob.u0) - end + function prob_func(prob, i, repeat) + remake(prob, u0 = 0.5 .+ i / 100 .* prob.u0) + end - ensemble_prob = EnsembleProblem(prob, prob_func = prob_func) - sim = solve(ensemble_prob, Tsit5(), EnsembleThreads(), saveat = 0.1, trajectories = 100) + ensemble_prob = EnsembleProblem(prob, prob_func = prob_func) + sim = solve(ensemble_prob, Tsit5(), EnsembleThreads(), saveat = 0.1, trajectories = 100) end # loss function -loss() = sum(abs2,1.0.-Array(model3())) +loss() = sum(abs2, 1.0 .- Array(model3())) data = Iterators.repeated((), 10) cb = function () # callback function to observe training - @show loss() + @show loss() end pa = [1.0] @@ -75,6 +76,6 @@ u0 = [3.0] opt = ADAM(0.1) println("Starting to train") l1 = loss() -Flux.@epochs 10 Flux.train!(loss, Flux.params([pa,u0]), data, opt; cb = cb) +Flux.@epochs 10 Flux.train!(loss, Flux.params([pa, u0]), data, opt; cb = cb) l2 = loss() @test 10l2 < l1 diff --git a/test/forward.jl b/test/forward.jl index d3d0b531b..2009a5976 100644 --- a/test/forward.jl +++ b/test/forward.jl @@ -1,87 +1,89 @@ using SciMLSensitivity, OrdinaryDiffEq, ForwardDiff, Calculus using Test -function fb(du,u,p,t) - du[1] = dx = p[1]*u[1] - p[2]*u[1]*u[2] - du[2] = dy = -t*p[3]*u[2] + t*u[1]*u[2] +function fb(du, u, p, t) + du[1] = dx = p[1] * u[1] - p[2] * u[1] * u[2] + du[2] = dy = -t * p[3] * u[2] + t * u[1] * u[2] end -function jac(J,u,p,t) - (x, y, a, b, c) = (u[1], u[2], p[1], p[2], p[3]) - J[1,1] = a + y * b * -1 - J[2,1] = t * y - J[1,2] = b * x * -1 - J[2,2] = t * c * -1 + t * x +function jac(J, u, p, t) + (x, y, a, b, c) = (u[1], u[2], p[1], p[2], p[3]) + J[1, 1] = a + y * b * -1 + J[2, 1] = t * y + J[1, 2] = b * x * -1 + J[2, 2] = t * c * -1 + t * x end -function paramjac(pJ,u,p,t) - (x, y, a, b, c) = (u[1], u[2], p[1], p[2], p[3]) - pJ[1,1] = x - pJ[2,1] = 0.0 - pJ[1,2] = - x * y - pJ[2,2] = 0.0 - pJ[1,3] = 0.0 - pJ[2,3] = - t * y +function paramjac(pJ, u, p, t) + (x, y, a, b, c) = (u[1], u[2], p[1], p[2], p[3]) + pJ[1, 1] = x + pJ[2, 1] = 0.0 + pJ[1, 2] = -x * y + pJ[2, 2] = 0.0 + pJ[1, 3] = 0.0 + pJ[2, 3] = -t * y end - -f = ODEFunction(fb,jac=jac,paramjac=paramjac) -p = [1.5,1.0,3.0] -prob = ODEForwardSensitivityProblem(f,[1.0;1.0],(0.0,10.0),p) -probInpl = ODEForwardSensitivityProblem(fb,[1.0;1.0],(0.0,10.0),p) -probnoad = ODEForwardSensitivityProblem(fb,[1.0;1.0],(0.0,10.0),p, - ForwardSensitivity(autodiff=false)) -probnoadjacvec = ODEForwardSensitivityProblem(fb,[1.0;1.0],(0.0,10.0),p, - ForwardSensitivity(autodiff=false,autojacvec=true)) -probnoad2 = ODEForwardSensitivityProblem(f,[1.0;1.0],(0.0,10.0),p, - ForwardSensitivity(autodiff=false)) -probvecmat = ODEForwardSensitivityProblem(fb,[1.0;1.0],(0.0,10.0),p, - ForwardSensitivity(autojacvec=false,autojacmat=true)) -sol = solve(prob,Tsit5(),abstol=1e-14,reltol=1e-14) -@test_broken solve(probInpl,KenCarp4(),abstol=1e-14,reltol=1e-14).retcode == :Success -solInpl = solve(probInpl,KenCarp4(autodiff=false),abstol=1e-14,reltol=1e-14) -solInpl2 = solve(probInpl,Rodas4(autodiff=false),abstol=1e-10,reltol=1e-10) -solnoad = solve(probnoad,KenCarp4(autodiff=false),abstol=1e-14,reltol=1e-14) -solnoadjacvec = solve(probnoadjacvec,KenCarp4(autodiff=false),abstol=1e-14,reltol=1e-14) -solnoad2 = solve(probnoad,KenCarp4(autodiff=false),abstol=1e-14,reltol=1e-14) -solvecmat = solve(probvecmat,Tsit5(),abstol=1e-14,reltol=1e-14) - -x = sol[1:sol.prob.f.numindvar,:] +f = ODEFunction(fb, jac = jac, paramjac = paramjac) +p = [1.5, 1.0, 3.0] +prob = ODEForwardSensitivityProblem(f, [1.0; 1.0], (0.0, 10.0), p) +probInpl = ODEForwardSensitivityProblem(fb, [1.0; 1.0], (0.0, 10.0), p) +probnoad = ODEForwardSensitivityProblem(fb, [1.0; 1.0], (0.0, 10.0), p, + ForwardSensitivity(autodiff = false)) +probnoadjacvec = ODEForwardSensitivityProblem(fb, [1.0; 1.0], (0.0, 10.0), p, + ForwardSensitivity(autodiff = false, + autojacvec = true)) +probnoad2 = ODEForwardSensitivityProblem(f, [1.0; 1.0], (0.0, 10.0), p, + ForwardSensitivity(autodiff = false)) +probvecmat = ODEForwardSensitivityProblem(fb, [1.0; 1.0], (0.0, 10.0), p, + ForwardSensitivity(autojacvec = false, + autojacmat = true)) +sol = solve(prob, Tsit5(), abstol = 1e-14, reltol = 1e-14) +@test_broken solve(probInpl, KenCarp4(), abstol = 1e-14, reltol = 1e-14).retcode == :Success +solInpl = solve(probInpl, KenCarp4(autodiff = false), abstol = 1e-14, reltol = 1e-14) +solInpl2 = solve(probInpl, Rodas4(autodiff = false), abstol = 1e-10, reltol = 1e-10) +solnoad = solve(probnoad, KenCarp4(autodiff = false), abstol = 1e-14, reltol = 1e-14) +solnoadjacvec = solve(probnoadjacvec, KenCarp4(autodiff = false), abstol = 1e-14, + reltol = 1e-14) +solnoad2 = solve(probnoad, KenCarp4(autodiff = false), abstol = 1e-14, reltol = 1e-14) +solvecmat = solve(probvecmat, Tsit5(), abstol = 1e-14, reltol = 1e-14) + +x = sol[1:(sol.prob.f.numindvar), :] @test sol(5.0) ≈ solnoad(5.0) @test sol(5.0) ≈ solnoad2(5.0) -@test sol(5.0) ≈ solnoadjacvec(5.0) atol=1e-6 rtol=1e-6 +@test sol(5.0)≈solnoadjacvec(5.0) atol=1e-6 rtol=1e-6 @test sol(5.0) ≈ solInpl(5.0) -@test isapprox(solInpl(5.0), solInpl2(5.0),rtol=1e-5) +@test isapprox(solInpl(5.0), solInpl2(5.0), rtol = 1e-5) @test sol(5.0) ≈ solvecmat(5.0) # Get the sensitivities -da = sol[sol.prob.f.numindvar+1:sol.prob.f.numindvar*2,:] -db = sol[sol.prob.f.numindvar*2+1:sol.prob.f.numindvar*3,:] -dc = sol[sol.prob.f.numindvar*3+1:sol.prob.f.numindvar*4,:] +da = sol[(sol.prob.f.numindvar + 1):(sol.prob.f.numindvar * 2), :] +db = sol[(sol.prob.f.numindvar * 2 + 1):(sol.prob.f.numindvar * 3), :] +dc = sol[(sol.prob.f.numindvar * 3 + 1):(sol.prob.f.numindvar * 4), :] -sense_res1 = [da[:,end] db[:,end] dc[:,end]] +sense_res1 = [da[:, end] db[:, end] dc[:, end]] -prob = ODEForwardSensitivityProblem(f.f,[1.0;1.0],(0.0,10.0),p, - ForwardSensitivity(autojacvec=true)) -sol = solve(prob,Tsit5(),abstol=1e-14,reltol=1e-14,saveat=0.01) -x = sol[1:sol.prob.f.numindvar,:] +prob = ODEForwardSensitivityProblem(f.f, [1.0; 1.0], (0.0, 10.0), p, + ForwardSensitivity(autojacvec = true)) +sol = solve(prob, Tsit5(), abstol = 1e-14, reltol = 1e-14, saveat = 0.01) +x = sol[1:(sol.prob.f.numindvar), :] # Get the sensitivities -res = sol[1:sol.prob.f.numindvar,:] -da = sol[sol.prob.f.numindvar+1:sol.prob.f.numindvar*2,:] -db = sol[sol.prob.f.numindvar*2+1:sol.prob.f.numindvar*3,:] -dc = sol[sol.prob.f.numindvar*3+1:sol.prob.f.numindvar*4,:] +res = sol[1:(sol.prob.f.numindvar), :] +da = sol[(sol.prob.f.numindvar + 1):(sol.prob.f.numindvar * 2), :] +db = sol[(sol.prob.f.numindvar * 2 + 1):(sol.prob.f.numindvar * 3), :] +dc = sol[(sol.prob.f.numindvar * 3 + 1):(sol.prob.f.numindvar * 4), :] -sense_res2 = [da[:,end] db[:,end] dc[:,end]] +sense_res2 = [da[:, end] db[:, end] dc[:, end]] function test_f(p) - prob = ODEProblem(f,eltype(p).([1.0,1.0]),(0.0,10.0),p) - solve(prob,Tsit5(),abstol=1e-14,reltol=1e-14,save_everystep=false)[end] + prob = ODEProblem(f, eltype(p).([1.0, 1.0]), (0.0, 10.0), p) + solve(prob, Tsit5(), abstol = 1e-14, reltol = 1e-14, save_everystep = false)[end] end -p = [1.5,1.0,3.0] -fd_res = ForwardDiff.jacobian(test_f,p) -calc_res = Calculus.finite_difference_jacobian(test_f,p) +p = [1.5, 1.0, 3.0] +fd_res = ForwardDiff.jacobian(test_f, p) +calc_res = Calculus.finite_difference_jacobian(test_f, p) @test sense_res1 ≈ sense_res2 ≈ fd_res @test sense_res1 ≈ sense_res2 ≈ calc_res @@ -92,26 +94,23 @@ xall, dpall = extract_local_sensitivities(sol) @test xall == res @test dpall[1] == da -_,dpall_matrix = extract_local_sensitivities(sol,Val(true)) -@test mapreduce(x->x[:, 2], hcat, dpall) == dpall_matrix[2] - +_, dpall_matrix = extract_local_sensitivities(sol, Val(true)) +@test mapreduce(x -> x[:, 2], hcat, dpall) == dpall_matrix[2] -x, dp = extract_local_sensitivities(sol,length(sol.t)) -sense_res2 = reduce(hcat,dp) +x, dp = extract_local_sensitivities(sol, length(sol.t)) +sense_res2 = reduce(hcat, dp) @test sense_res1 == sense_res2 -@test extract_local_sensitivities(sol,sol.t[3]) == extract_local_sensitivities(sol,3) +@test extract_local_sensitivities(sol, sol.t[3]) == extract_local_sensitivities(sol, 3) tmp = similar(sol[1]) -@test extract_local_sensitivities(tmp,sol,sol.t[3]) == extract_local_sensitivities(sol,3) - +@test extract_local_sensitivities(tmp, sol, sol.t[3]) == extract_local_sensitivities(sol, 3) # asmatrix=true @test extract_local_sensitivities(sol, length(sol), true) == (x, sense_res2) @test extract_local_sensitivities(sol, sol.t[end], true) == (x, sense_res2) @test extract_local_sensitivities(tmp, sol, sol.t[end], true) == (x, sense_res2) - # Return type inferred @inferred extract_local_sensitivities(sol, 1) @inferred extract_local_sensitivities(sol, 1, Val(true)) @@ -122,33 +121,31 @@ tmp = similar(sol[1]) ### ForwardDiff version -prob = ODEForwardSensitivityProblem(f.f,[1.0;1.0],(0.0,10.0),p, +prob = ODEForwardSensitivityProblem(f.f, [1.0; 1.0], (0.0, 10.0), p, ForwardDiffSensitivity()) -sol = solve(prob,Tsit5(),abstol=1e-14,reltol=1e-14,saveat=0.01) +sol = solve(prob, Tsit5(), abstol = 1e-14, reltol = 1e-14, saveat = 0.01) xall, dpall = extract_local_sensitivities(sol) @test xall ≈ res -@test dpall[1] ≈ da atol=1e-9 +@test dpall[1]≈da atol=1e-9 -_,dpall_matrix = extract_local_sensitivities(sol,Val(true)) -@test mapreduce(x->x[:, 2], hcat, dpall) == dpall_matrix[2] +_, dpall_matrix = extract_local_sensitivities(sol, Val(true)) +@test mapreduce(x -> x[:, 2], hcat, dpall) == dpall_matrix[2] -x, dp = extract_local_sensitivities(sol,length(sol.t)) -sense_res2 = reduce(hcat,dp) +x, dp = extract_local_sensitivities(sol, length(sol.t)) +sense_res2 = reduce(hcat, dp) @test fd_res == sense_res2 -@test extract_local_sensitivities(sol,sol.t[3]) == extract_local_sensitivities(sol,3) +@test extract_local_sensitivities(sol, sol.t[3]) == extract_local_sensitivities(sol, 3) tmp = similar(sol[1]) -@test extract_local_sensitivities(tmp,sol,sol.t[3]) == extract_local_sensitivities(sol,3) - +@test extract_local_sensitivities(tmp, sol, sol.t[3]) == extract_local_sensitivities(sol, 3) # asmatrix=true @test extract_local_sensitivities(sol, length(sol), true) == (x, sense_res2) @test extract_local_sensitivities(sol, sol.t[end], true) == (x, sense_res2) @test extract_local_sensitivities(tmp, sol, sol.t[end], true) == (x, sense_res2) - # Return type inferred @inferred extract_local_sensitivities(sol, 1) @inferred extract_local_sensitivities(sol, 1, Val(true)) @@ -156,7 +153,7 @@ tmp = similar(sol[1]) @inferred extract_local_sensitivities(sol, sol.t[3], Val(true)) @inferred extract_local_sensitivities(tmp, sol, sol.t[3]) @inferred extract_local_sensitivities(tmp, sol, sol.t[3], Val(true)) - + # Test mass matrix function rober_MM(du, u, p, t) y₁, y₂, y₃ = u @@ -171,7 +168,7 @@ function rober_no_MM(du, u, p, t) k₁, k₂, k₃ = p du[1] = -k₁ * y₁ + k₃ * y₂ * y₃ du[2] = k₁ * y₁ - k₂ * y₂^2 - k₃ * y₂ * y₃ - du[3] = k₂*y₂^2 + du[3] = k₂ * y₂^2 nothing end @@ -180,42 +177,53 @@ p = [0.04, 3e7, 1e4] u0 = [1.0, 0.0, 0.0] tspan = (0.0, 12.0) -f_MM= ODEFunction(rober_MM, mass_matrix = M) -f_no_MM= ODEFunction(rober_no_MM) +f_MM = ODEFunction(rober_MM, mass_matrix = M) +f_no_MM = ODEFunction(rober_no_MM) -prob_MM_ForwardSensitivity = ODEForwardSensitivityProblem(f_MM, u0, tspan, p, ForwardSensitivity()) -sol_MM_ForwardSensitivity = solve(prob_MM_ForwardSensitivity , Rodas4(autodiff = false), reltol = 1e-14, abstol = 1e-14) +prob_MM_ForwardSensitivity = ODEForwardSensitivityProblem(f_MM, u0, tspan, p, + ForwardSensitivity()) +sol_MM_ForwardSensitivity = solve(prob_MM_ForwardSensitivity, Rodas4(autodiff = false), + reltol = 1e-14, abstol = 1e-14) -prob_MM_ForwardDiffSensitivity = ODEForwardSensitivityProblem(f_MM, u0, tspan, p, ForwardDiffSensitivity()) -sol_MM_ForwardDiffSensitivity = solve(prob_MM_ForwardDiffSensitivity, Rodas4(autodiff = false), reltol = 1e-14, abstol = 1e-14) +prob_MM_ForwardDiffSensitivity = ODEForwardSensitivityProblem(f_MM, u0, tspan, p, + ForwardDiffSensitivity()) +sol_MM_ForwardDiffSensitivity = solve(prob_MM_ForwardDiffSensitivity, + Rodas4(autodiff = false), reltol = 1e-14, + abstol = 1e-14) prob_no_MM = ODEForwardSensitivityProblem(f_no_MM, u0, tspan, p, ForwardSensitivity()) -sol_no_MM= solve(prob_no_MM, Rodas4(autodiff = false), reltol = 1e-14, abstol = 1e-14) +sol_no_MM = solve(prob_no_MM, Rodas4(autodiff = false), reltol = 1e-14, abstol = 1e-14) -sen_MM_ForwardSensitivity = extract_local_sensitivities(sol_MM_ForwardSensitivity,10.0,true) -sen_MM_ForwardDiffSensitivity = extract_local_sensitivities(sol_MM_ForwardDiffSensitivity,10.0,true) -sen_no_MM = extract_local_sensitivities(sol_no_MM,10.0,true) +sen_MM_ForwardSensitivity = extract_local_sensitivities(sol_MM_ForwardSensitivity, 10.0, + true) +sen_MM_ForwardDiffSensitivity = extract_local_sensitivities(sol_MM_ForwardDiffSensitivity, + 10.0, true) +sen_no_MM = extract_local_sensitivities(sol_no_MM, 10.0, true) -@test sen_MM_ForwardSensitivity[2] ≈ sen_MM_ForwardDiffSensitivity[2] atol=1e-10 rtol=1e-10 -@test sen_MM_ForwardSensitivity[2] ≈ sen_no_MM[2] atol=1e-10 rtol=1e-10 +@test sen_MM_ForwardSensitivity[2]≈sen_MM_ForwardDiffSensitivity[2] atol=1e-10 rtol=1e-10 +@test sen_MM_ForwardSensitivity[2]≈sen_no_MM[2] atol=1e-10 rtol=1e-10 # Test Float32 -function f32(du,u,p,t) - du[1] = dx = p[1]*u[1] - p[2]*u[1]*u[2] - du[2] = dy = -p[3]*u[2] + u[1]*u[2] +function f32(du, u, p, t) + du[1] = dx = p[1] * u[1] - p[2] * u[1] * u[2] + du[2] = dy = -p[3] * u[2] + u[1] * u[2] end -p = [1.5f0,1.0f0,3.0f0] -prob = ODEForwardSensitivityProblem(f32,[1.0f0;1.0f0],(0.0f0,10.0f0),p) -sol = solve(prob,Tsit5()) +p = [1.5f0, 1.0f0, 3.0f0] +prob = ODEForwardSensitivityProblem(f32, [1.0f0; 1.0f0], (0.0f0, 10.0f0), p) +sol = solve(prob, Tsit5()) # Out Of Place Error function lotka_volterra_oop(u, p, t) - du = zeros(2) - du[1] = p[1]*u[1] - p[2]*u[1]*u[2] - du[2] = -p[3]*u[2] + p[4]*u[1]*u[2] - return du + du = zeros(2) + du[1] = p[1] * u[1] - p[2] * u[1] * u[2] + du[2] = -p[3] * u[2] + p[4] * u[1] * u[2] + return du end u0 = [1.0, 1.0] p = [1.5, 1.0, 3.0, 1.0] -@test_throws SciMLSensitivity.ForwardSensitivityOutOfPlaceError ODEForwardSensitivityProblem(lotka_volterra_oop, u0, (0.0, 10.0), p) \ No newline at end of file +@test_throws SciMLSensitivity.ForwardSensitivityOutOfPlaceError ODEForwardSensitivityProblem(lotka_volterra_oop, + u0, + (0.0, + 10.0), + p) diff --git a/test/forward_chunking.jl b/test/forward_chunking.jl index ffff44447..f70f6eab9 100644 --- a/test/forward_chunking.jl +++ b/test/forward_chunking.jl @@ -1,108 +1,130 @@ using SciMLSensitivity, OrdinaryDiffEq, Zygote, Test, ForwardDiff -function fiip(du,u,p,t) - du[1] = dx = p[1]*u[1] - p[2]*p[51]*p[75]*u[1]*u[2] - du[2] = dy = -p[3]*p[81]*p[25]*u[2] + (sum(@view(p[4:end]))/100)*u[1]*u[2] +function fiip(du, u, p, t) + du[1] = dx = p[1] * u[1] - p[2] * p[51] * p[75] * u[1] * u[2] + du[2] = dy = -p[3] * p[81] * p[25] * u[2] + (sum(@view(p[4:end])) / 100) * u[1] * u[2] end -function foop(u,p,t) - dx = p[1]*u[1] - p[2]*p[51]*p[75]*u[1]*u[2] - dy = -p[3]*p[81]*p[25]*u[2] + (sum(@view(p[4:end]))/100)*p[4]*u[1]*u[2] - [dx,dy] +function foop(u, p, t) + dx = p[1] * u[1] - p[2] * p[51] * p[75] * u[1] * u[2] + dy = -p[3] * p[81] * p[25] * u[2] + (sum(@view(p[4:end])) / 100) * p[4] * u[1] * u[2] + [dx, dy] end -p = [1.5,1.0,3.0,1.0]; u0 = [1.0;1.0] -p = reshape(vcat(p,ones(100)),4,26) -prob = ODEProblem(fiip,u0,(0.0,10.0),p) -proboop = ODEProblem(foop,u0,(0.0,10.0),p) - -loss = (u0,p)->sum(solve(prob,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.1,sensealg=ForwardDiffSensitivity())) -@time du01,dp1 = Zygote.gradient(loss,u0,p) - -loss = (u0,p)->sum(solve(prob,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.1,sensealg=InterpolatingAdjoint())) -@time du02,dp2 = Zygote.gradient(loss,u0,p) - -loss = (u0,p)->sum(solve(prob,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.1,sensealg=ForwardDiffSensitivity(chunk_size=104))) -@time du03,dp3 = Zygote.gradient(loss,u0,p) - -dp = ForwardDiff.gradient(p->loss(u0,p),p) -du0 = ForwardDiff.gradient(u0->loss(u0,p),u0) - -@test du01 ≈ du0 rtol=1e-12 -@test du01 ≈ du02 rtol=1e-12 -@test du01 ≈ du03 rtol=1e-12 -@test dp1 ≈ dp rtol=1e-12 -@test dp1 ≈ dp2 rtol=1e-12 -@test dp1 ≈ dp3 rtol=1e-12 - -loss = (u0,p)->sum(solve(proboop,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.1,sensealg=ForwardDiffSensitivity())) -@time du01,dp1 = Zygote.gradient(loss,u0,p) - -loss = (u0,p)->sum(solve(proboop,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.1,sensealg=InterpolatingAdjoint())) -@time du02,dp2 = Zygote.gradient(loss,u0,p) - -loss = (u0,p)->sum(solve(proboop,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.1,sensealg=ForwardDiffSensitivity(chunk_size=104))) -@time du03,dp3 = Zygote.gradient(loss,u0,p) - -dp = ForwardDiff.gradient(p->loss(u0,p),p) -du0 = ForwardDiff.gradient(u0->loss(u0,p),u0) - -@test du01 ≈ du0 rtol=1e-12 -@test du01 ≈ du02 rtol=1e-12 -@test du01 ≈ du03 rtol=1e-12 -@test dp1 ≈ dp rtol=1e-12 -@test dp1 ≈ dp2 rtol=1e-12 -@test dp1 ≈ dp3 rtol=1e-12 - -function fiip(du,u,p,t) - du[1] = dx = p[1]*u[1] - p[2]*u[1]*u[2] - du[2] = dy = -p[3]*u[2] + p[4]*u[1]*u[2] - du[3:end] .= p[4] +p = [1.5, 1.0, 3.0, 1.0]; +u0 = [1.0; 1.0]; +p = reshape(vcat(p, ones(100)), 4, 26) +prob = ODEProblem(fiip, u0, (0.0, 10.0), p) +proboop = ODEProblem(foop, u0, (0.0, 10.0), p) + +loss = (u0, p) -> sum(solve(prob, Tsit5(), u0 = u0, p = p, abstol = 1e-14, reltol = 1e-14, + saveat = 0.1, sensealg = ForwardDiffSensitivity())) +@time du01, dp1 = Zygote.gradient(loss, u0, p) + +loss = (u0, p) -> sum(solve(prob, Tsit5(), u0 = u0, p = p, abstol = 1e-14, reltol = 1e-14, + saveat = 0.1, sensealg = InterpolatingAdjoint())) +@time du02, dp2 = Zygote.gradient(loss, u0, p) + +loss = (u0, p) -> sum(solve(prob, Tsit5(), u0 = u0, p = p, abstol = 1e-14, reltol = 1e-14, + saveat = 0.1, + sensealg = ForwardDiffSensitivity(chunk_size = 104))) +@time du03, dp3 = Zygote.gradient(loss, u0, p) + +dp = ForwardDiff.gradient(p -> loss(u0, p), p) +du0 = ForwardDiff.gradient(u0 -> loss(u0, p), u0) + +@test du01≈du0 rtol=1e-12 +@test du01≈du02 rtol=1e-12 +@test du01≈du03 rtol=1e-12 +@test dp1≈dp rtol=1e-12 +@test dp1≈dp2 rtol=1e-12 +@test dp1≈dp3 rtol=1e-12 + +loss = (u0, p) -> sum(solve(proboop, Tsit5(), u0 = u0, p = p, abstol = 1e-14, + reltol = 1e-14, saveat = 0.1, + sensealg = ForwardDiffSensitivity())) +@time du01, dp1 = Zygote.gradient(loss, u0, p) + +loss = (u0, p) -> sum(solve(proboop, Tsit5(), u0 = u0, p = p, abstol = 1e-14, + reltol = 1e-14, saveat = 0.1, + sensealg = InterpolatingAdjoint())) +@time du02, dp2 = Zygote.gradient(loss, u0, p) + +loss = (u0, p) -> sum(solve(proboop, Tsit5(), u0 = u0, p = p, abstol = 1e-14, + reltol = 1e-14, saveat = 0.1, + sensealg = ForwardDiffSensitivity(chunk_size = 104))) +@time du03, dp3 = Zygote.gradient(loss, u0, p) + +dp = ForwardDiff.gradient(p -> loss(u0, p), p) +du0 = ForwardDiff.gradient(u0 -> loss(u0, p), u0) + +@test du01≈du0 rtol=1e-12 +@test du01≈du02 rtol=1e-12 +@test du01≈du03 rtol=1e-12 +@test dp1≈dp rtol=1e-12 +@test dp1≈dp2 rtol=1e-12 +@test dp1≈dp3 rtol=1e-12 + +function fiip(du, u, p, t) + du[1] = dx = p[1] * u[1] - p[2] * u[1] * u[2] + du[2] = dy = -p[3] * u[2] + p[4] * u[1] * u[2] + du[3:end] .= p[4] end -function foop(u,p,t) - dx = p[1]*u[1] - p[2]*u[1]*u[2] - dy = -p[3]*u[2] + p[4]*u[1]*u[2] - reshape(vcat(dx,dy,repeat([p[4]],100)),2,51) +function foop(u, p, t) + dx = p[1] * u[1] - p[2] * u[1] * u[2] + dy = -p[3] * u[2] + p[4] * u[1] * u[2] + reshape(vcat(dx, dy, repeat([p[4]], 100)), 2, 51) end -p = [1.5,1.0,3.0,1.0]; u0 = [1.0;1.0] -u0 = reshape(vcat(u0,ones(100)),2,51) -prob = ODEProblem(fiip,u0,(0.0,10.0),p) -proboop = ODEProblem(foop,u0,(0.0,10.0),p) - -loss = (u0,p)->sum(solve(prob,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.1,sensealg=ForwardDiffSensitivity())) -@time du01,dp1 = Zygote.gradient(loss,u0,p) - -loss = (u0,p)->sum(solve(prob,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.1,sensealg=InterpolatingAdjoint())) -@time du02,dp2 = Zygote.gradient(loss,u0,p) - -loss = (u0,p)->sum(solve(prob,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.1,sensealg=ForwardDiffSensitivity(chunk_size=102))) -@time du03,dp3 = Zygote.gradient(loss,u0,p) - -dp = ForwardDiff.gradient(p->loss(u0,p),p) -du0 = ForwardDiff.gradient(u0->loss(u0,p),u0) - -@test du01 ≈ du0 rtol=1e-12 -@test du01 ≈ du02 rtol=1e-12 -@test du01 ≈ du03 rtol=1e-12 -@test dp1 ≈ dp rtol=1e-12 -@test dp1 ≈ dp2 rtol=1e-12 -@test dp1 ≈ dp3 rtol=1e-12 - -loss = (u0,p)->sum(solve(proboop,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.1,sensealg=ForwardDiffSensitivity())) -@time du01,dp1 = Zygote.gradient(loss,u0,p) - -loss = (u0,p)->sum(solve(proboop,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.1,sensealg=InterpolatingAdjoint())) -@time du02,dp2 = Zygote.gradient(loss,u0,p) - -loss = (u0,p)->sum(solve(proboop,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.1,sensealg=ForwardDiffSensitivity(chunk_size=102))) -@time du03,dp3 = Zygote.gradient(loss,u0,p) - -dp = ForwardDiff.gradient(p->loss(u0,p),p) -du0 = ForwardDiff.gradient(u0->loss(u0,p),u0) - -@test du01 ≈ du0 rtol=1e-12 -@test du01 ≈ du02 rtol=1e-12 -@test du01 ≈ du03 rtol=1e-12 -@test dp1 ≈ dp rtol=1e-12 -@test dp1 ≈ dp2 rtol=1e-12 -@test dp1 ≈ dp3 rtol=1e-12 +p = [1.5, 1.0, 3.0, 1.0]; +u0 = [1.0; 1.0]; +u0 = reshape(vcat(u0, ones(100)), 2, 51) +prob = ODEProblem(fiip, u0, (0.0, 10.0), p) +proboop = ODEProblem(foop, u0, (0.0, 10.0), p) + +loss = (u0, p) -> sum(solve(prob, Tsit5(), u0 = u0, p = p, abstol = 1e-14, reltol = 1e-14, + saveat = 0.1, sensealg = ForwardDiffSensitivity())) +@time du01, dp1 = Zygote.gradient(loss, u0, p) + +loss = (u0, p) -> sum(solve(prob, Tsit5(), u0 = u0, p = p, abstol = 1e-14, reltol = 1e-14, + saveat = 0.1, sensealg = InterpolatingAdjoint())) +@time du02, dp2 = Zygote.gradient(loss, u0, p) + +loss = (u0, p) -> sum(solve(prob, Tsit5(), u0 = u0, p = p, abstol = 1e-14, reltol = 1e-14, + saveat = 0.1, + sensealg = ForwardDiffSensitivity(chunk_size = 102))) +@time du03, dp3 = Zygote.gradient(loss, u0, p) + +dp = ForwardDiff.gradient(p -> loss(u0, p), p) +du0 = ForwardDiff.gradient(u0 -> loss(u0, p), u0) + +@test du01≈du0 rtol=1e-12 +@test du01≈du02 rtol=1e-12 +@test du01≈du03 rtol=1e-12 +@test dp1≈dp rtol=1e-12 +@test dp1≈dp2 rtol=1e-12 +@test dp1≈dp3 rtol=1e-12 + +loss = (u0, p) -> sum(solve(proboop, Tsit5(), u0 = u0, p = p, abstol = 1e-14, + reltol = 1e-14, saveat = 0.1, + sensealg = ForwardDiffSensitivity())) +@time du01, dp1 = Zygote.gradient(loss, u0, p) + +loss = (u0, p) -> sum(solve(proboop, Tsit5(), u0 = u0, p = p, abstol = 1e-14, + reltol = 1e-14, saveat = 0.1, + sensealg = InterpolatingAdjoint())) +@time du02, dp2 = Zygote.gradient(loss, u0, p) + +loss = (u0, p) -> sum(solve(proboop, Tsit5(), u0 = u0, p = p, abstol = 1e-14, + reltol = 1e-14, saveat = 0.1, + sensealg = ForwardDiffSensitivity(chunk_size = 102))) +@time du03, dp3 = Zygote.gradient(loss, u0, p) + +dp = ForwardDiff.gradient(p -> loss(u0, p), p) +du0 = ForwardDiff.gradient(u0 -> loss(u0, p), u0) + +@test du01≈du0 rtol=1e-12 +@test du01≈du02 rtol=1e-12 +@test du01≈du03 rtol=1e-12 +@test dp1≈dp rtol=1e-12 +@test dp1≈dp2 rtol=1e-12 +@test dp1≈dp3 rtol=1e-12 diff --git a/test/forward_prob_kwargs.jl b/test/forward_prob_kwargs.jl index f70418557..b6fedc928 100644 --- a/test/forward_prob_kwargs.jl +++ b/test/forward_prob_kwargs.jl @@ -5,29 +5,30 @@ using FiniteDiff using Zygote using ForwardDiff -u0 = [1.0,1.0] -p = [1.5,1.0,3.0,1.0] -function fiip(du,u,p,t) - du[1] = dx = p[1]*u[1] - p[2]*u[1]*u[2] - du[2] = dy = -p[3]*u[2] + p[4]*u[1]*u[2] +u0 = [1.0, 1.0] +p = [1.5, 1.0, 3.0, 1.0] +function fiip(du, u, p, t) + du[1] = dx = p[1] * u[1] - p[2] * u[1] * u[2] + du[2] = dy = -p[3] * u[2] + p[4] * u[1] * u[2] end -prob = ODEProblem(fiip,u0,(0.0,10.0),[1.5,1.0,3.0,1.0],reltol = 1e-14,abstol=1e-14) +prob = ODEProblem(fiip, u0, (0.0, 10.0), [1.5, 1.0, 3.0, 1.0], reltol = 1e-14, + abstol = 1e-14) function cost(p1) - _prob = remake(prob,p=vcat(p1,p[2:end])) - sol = solve(_prob,Tsit5(),sensealg=ForwardDiffSensitivity(),saveat=0.1) + _prob = remake(prob, p = vcat(p1, p[2:end])) + sol = solve(_prob, Tsit5(), sensealg = ForwardDiffSensitivity(), saveat = 0.1) sum(sol) end -res = FiniteDiff.finite_difference_derivative(cost,p[1]) # 8.305557728239275 -res2 = ForwardDiff.derivative(cost,p[1]) # 8.305305252400714 # only 1 dual number -res3 = Zygote.gradient(cost,p[1])[1] # (8.305266428305409,) # 4 dual numbers +res = FiniteDiff.finite_difference_derivative(cost, p[1]) # 8.305557728239275 +res2 = ForwardDiff.derivative(cost, p[1]) # 8.305305252400714 # only 1 dual number +res3 = Zygote.gradient(cost, p[1])[1] # (8.305266428305409,) # 4 dual numbers function cost(p1) - _prob = remake(prob,p=vcat(p1,p[2:end])) - sol = solve(_prob,Tsit5(),sensealg=ForwardSensitivity(),saveat=0.1) + _prob = remake(prob, p = vcat(p1, p[2:end])) + sol = solve(_prob, Tsit5(), sensealg = ForwardSensitivity(), saveat = 0.1) sum(sol) end -res4 = Zygote.gradient(cost,p[1])[1] # (7.720368430265481,) +res4 = Zygote.gradient(cost, p[1])[1] # (7.720368430265481,) @test res ≈ res2 @test res ≈ res3 diff --git a/test/forward_remake.jl b/test/forward_remake.jl index 5ffd1bdba..050c8d5c8 100644 --- a/test/forward_remake.jl +++ b/test/forward_remake.jl @@ -1,35 +1,36 @@ using SciMLSensitivity, ForwardDiff, Distributions, OrdinaryDiffEq, LinearAlgebra, Test -function fiip(du,u,p,t) - du[1] = dx = p[1]*u[1] - p[2]*u[1]*u[2] - du[2] = dy = -p[3]*u[2] + p[4]*u[1]*u[2] +function fiip(du, u, p, t) + du[1] = dx = p[1] * u[1] - p[2] * u[1] * u[2] + du[2] = dy = -p[3] * u[2] + p[4] * u[1] * u[2] end function g(sol) - J = extract_local_sensitivities(sol,true)[2] - det(J'*J) + J = extract_local_sensitivities(sol, true)[2] + det(J' * J) end -u0 = [1.0,1.0] -p = [1.5,1.0,3.0,1.0] -prob = ODEForwardSensitivityProblem(fiip,u0,(0.0,10.0),p,saveat=0:10) +u0 = [1.0, 1.0] +p = [1.5, 1.0, 3.0, 1.0] +prob = ODEForwardSensitivityProblem(fiip, u0, (0.0, 10.0), p, saveat = 0:10) sol = solve(prob, Tsit5()) -u0_dist = [Uniform(0.9,1.1), 1.0] -p_dist = [1.5, truncated(Normal(1.5,.1),1.1, 1.9),3.0,1.0] -u0_dist_extended = vcat(u0_dist,zeros(length(p)*length(u0))) +u0_dist = [Uniform(0.9, 1.1), 1.0] +p_dist = [1.5, truncated(Normal(1.5, 0.1), 1.1, 1.9), 3.0, 1.0] +u0_dist_extended = vcat(u0_dist, zeros(length(p) * length(u0))) function fiip_expe_SciML_forw_sen_SciML() - prob = ODEForwardSensitivityProblem(fiip,u0,(0.0,10.0),p,saveat=0:10) + prob = ODEForwardSensitivityProblem(fiip, u0, (0.0, 10.0), p, saveat = 0:10) prob_func = function (prob, i, repeat) - _prob = remake(prob, u0=[isa(ui,Distribution) ? rand(ui) : ui for ui in u0_dist], p=[isa(pj,Distribution) ? rand(pj) : pj for pj in p_dist]) + _prob = remake(prob, u0 = [isa(ui, Distribution) ? rand(ui) : ui for ui in u0_dist], + p = [isa(pj, Distribution) ? rand(pj) : pj for pj in p_dist]) _prob end output_func = function (sol, i) (g(sol), false) end - monte_prob = EnsembleProblem(prob;output_func=output_func,prob_func=prob_func) - sol = solve(monte_prob,Tsit5(),EnsembleSerial(),trajectories=100_000) + monte_prob = EnsembleProblem(prob; output_func = output_func, prob_func = prob_func) + sol = solve(monte_prob, Tsit5(), EnsembleSerial(), trajectories = 100_000) mean(sol.u) end -@test fiip_expe_SciML_forw_sen_SciML() ≈ 3.56e6 rtol=4e-2 +@test fiip_expe_SciML_forw_sen_SciML()≈3.56e6 rtol=4e-2 diff --git a/test/forwarddiffsensitivity_sparsity_components.jl b/test/forwarddiffsensitivity_sparsity_components.jl index bfb14db0d..2e912ecb2 100644 --- a/test/forwarddiffsensitivity_sparsity_components.jl +++ b/test/forwarddiffsensitivity_sparsity_components.jl @@ -2,34 +2,35 @@ using OrdinaryDiffEq, SciMLSensitivity, Flux using ComponentArrays, LinearAlgebra, Optimization, Test const nknots = 10 -const h = 1.0/(nknots+1) -x = range(0, step=h, length=nknots) -u0 = sin.(π*x) +const h = 1.0 / (nknots + 1) +x = range(0, step = h, length = nknots) +u0 = sin.(π * x) -@inline function f(du,u,p,t) - du .= zero(eltype(u)) - u₃ = @view u[3:end] - u₂ = @view u[2:end-1] - u₁ = @view u[1:end-2] - @. du[2:end-1] = p.k*((u₃ - 2*u₂ + u₁)/(h^2.0)) - nothing +@inline function f(du, u, p, t) + du .= zero(eltype(u)) + u₃ = @view u[3:end] + u₂ = @view u[2:(end - 1)] + u₁ = @view u[1:(end - 2)] + @. du[2:(end - 1)] = p.k * ((u₃ - 2 * u₂ + u₁) / (h^2.0)) + nothing end -p_true = ComponentArray(k=0.42) -jac_proto = Tridiagonal(similar(u0,nknots-1), similar(u0), similar(u0, nknots-1)) -prob = ODEProblem(ODEFunction(f,jac_prototype=jac_proto), u0, (0.0,1.0), p_true) -@time sol_true = solve(prob, Rodas4P(), saveat=0.1) +p_true = ComponentArray(k = 0.42) +jac_proto = Tridiagonal(similar(u0, nknots - 1), similar(u0), similar(u0, nknots - 1)) +prob = ODEProblem(ODEFunction(f, jac_prototype = jac_proto), u0, (0.0, 1.0), p_true) +@time sol_true = solve(prob, Rodas4P(), saveat = 0.1) function loss(p) - _prob = remake(prob, p=p) - sol = solve(_prob, Rodas4P(autodiff=false), saveat=0.1, sensealg=ForwardDiffSensitivity()) - sum((sol .- sol_true).^2) + _prob = remake(prob, p = p) + sol = solve(_prob, Rodas4P(autodiff = false), saveat = 0.1, + sensealg = ForwardDiffSensitivity()) + sum((sol .- sol_true) .^ 2) end -p0 = ComponentArray(k=1.0) +p0 = ComponentArray(k = 1.0) -optf = Optimization.OptimizationFunction((x,p) -> loss(x), Optimization.AutoZygote()) +optf = Optimization.OptimizationFunction((x, p) -> loss(x), Optimization.AutoZygote()) optprob = Optimization.OptimizationProblem(optf, p0) res = Optimization.solve(optprob, ADAM(0.01), maxiters = 100) -@test res.u.k ≈ 0.42461977305259074 rtol=1e-1 +@test res.u.k≈0.42461977305259074 rtol=1e-1 diff --git a/test/gdp_regression_test.jl b/test/gdp_regression_test.jl index b0fe03cdb..74ffddd9a 100644 --- a/test/gdp_regression_test.jl +++ b/test/gdp_regression_test.jl @@ -1,6 +1,66 @@ using SciMLSensitivity, Flux, OrdinaryDiffEq, LinearAlgebra, Test -GDP = [11394358246872.6, 11886411296037.9, 12547852149499.6, 13201781525927, 14081902622923.3, 14866223429278.3, 15728198883149.2, 16421593575529.9, 17437921118338, 18504710349537.1, 19191754995907.1, 20025063402734.2, 21171619915190.4, 22549236163304.4, 22999815176366.2, 23138798276196.2, 24359046058098.6, 25317009721600.9, 26301669369287.8, 27386035164588.8, 27907493159394.4, 28445139283067.1, 28565588996657.6, 29255060755937.6, 30574152048605.8, 31710451102539.4, 32786657119472.8, 34001004119223.5, 35570841010027.7, 36878317437617.5, 37952345258555.4, 38490918890678.7, 39171116855465.5, 39772082901255.8, 40969517920094.4, 42210614326789.4, 43638265675924.6, 45254805649669.6, 46411399944618.2, 47929948653387.3, 50036361141742.2, 51009550274808.6, 52127765360545.5, 53644090247696.9, 55995239099025.6, 58161311618934.2, 60681422072544.7, 63240595965946.1, 64413060738139.7, 63326658023605.9, 66036918504601.7, 68100669928597.9, 69811348331640.1, 71662400667935.7, 73698404958519.1, 75802901433146, 77752106717302.4, 80209237761564.8, 82643194654568.3] +GDP = [ + 11394358246872.6, + 11886411296037.9, + 12547852149499.6, + 13201781525927, + 14081902622923.3, + 14866223429278.3, + 15728198883149.2, + 16421593575529.9, + 17437921118338, + 18504710349537.1, + 19191754995907.1, + 20025063402734.2, + 21171619915190.4, + 22549236163304.4, + 22999815176366.2, + 23138798276196.2, + 24359046058098.6, + 25317009721600.9, + 26301669369287.8, + 27386035164588.8, + 27907493159394.4, + 28445139283067.1, + 28565588996657.6, + 29255060755937.6, + 30574152048605.8, + 31710451102539.4, + 32786657119472.8, + 34001004119223.5, + 35570841010027.7, + 36878317437617.5, + 37952345258555.4, + 38490918890678.7, + 39171116855465.5, + 39772082901255.8, + 40969517920094.4, + 42210614326789.4, + 43638265675924.6, + 45254805649669.6, + 46411399944618.2, + 47929948653387.3, + 50036361141742.2, + 51009550274808.6, + 52127765360545.5, + 53644090247696.9, + 55995239099025.6, + 58161311618934.2, + 60681422072544.7, + 63240595965946.1, + 64413060738139.7, + 63326658023605.9, + 66036918504601.7, + 68100669928597.9, + 69811348331640.1, + 71662400667935.7, + 73698404958519.1, + 75802901433146, + 77752106717302.4, + 80209237761564.8, + 82643194654568.3, +] function monomial(cGDP, parameters, t) α1, β1, nu1, nu2, δ, δ2 = parameters @@ -12,12 +72,13 @@ tspan = (1.0, 59.0) p = [474.8501513113645, 0.7036417845990167, 0.0, 1e-10, 1e-10, 1e-10] u0 = [GDP0] if false - prob = ODEProblem(monomial,[GDP0],tspan,p) + prob = ODEProblem(monomial, [GDP0], tspan, p) else ## false crashes. that is when i am tracking the initial conditions - prob = ODEProblem(monomial,u0,tspan,p) + prob = ODEProblem(monomial, u0, tspan, p) end function predict_rd() # Our 1-layer neural network - Array(solve(prob,Tsit5(),p=p,saveat=1.0:1.0:59.0,reltol=1e-4,sensealg=TrackerAdjoint())) + Array(solve(prob, Tsit5(), p = p, saveat = 1.0:1.0:59.0, reltol = 1e-4, + sensealg = TrackerAdjoint())) end function loss_rd() ##L2 norm biases the newer times unfairly @@ -25,8 +86,8 @@ function loss_rd() ##L2 norm biases the newer times unfairly c = 0.0 a = predict_rd() d = 0.0 - for i=1:59 - c += (a[i][1]/GDP[i]-1)^2 ## L2 of relative error + for i in 1:59 + c += (a[i][1] / GDP[i] - 1)^2 ## L2 of relative error end c + 3 * d end @@ -36,11 +97,11 @@ opt = ADAM(0.01) peek = function () #callback function to observe training #reduces training speed by a lot - println("Loss: ",loss_rd()) + println("Loss: ", loss_rd()) end peek() -Flux.train!(loss_rd, Flux.params(p,u0), data, opt, cb=peek) +Flux.train!(loss_rd, Flux.params(p, u0), data, opt, cb = peek) peek() @test loss_rd() < 0.2 @@ -50,19 +111,19 @@ function monomial(dcGDP, cGDP, parameters, t) dcGDP[1] = α1 * ((cGDP[1]))^β1 end - GDP0 = GDP[1] tspan = (1.0, 59.0) p = [474.8501513113645, 0.7036417845990167, 0.0, 1e-10, 1e-10, 1e-10] u0 = [GDP0] if false - prob = ODEProblem(monomial,[GDP0],tspan,p) + prob = ODEProblem(monomial, [GDP0], tspan, p) else ## false crashes. that is when i am tracking the initial conditions - prob = ODEProblem(monomial,u0,tspan,p) + prob = ODEProblem(monomial, u0, tspan, p) end function predict_adjoint() # Our 1-layer neural network - Array(solve(prob,Tsit5(),p=p,saveat=1.0,reltol=1e-4,sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true)))) + Array(solve(prob, Tsit5(), p = p, saveat = 1.0, reltol = 1e-4, + sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP(true)))) end function loss_adjoint() ##L2 norm biases the newer times unfairly @@ -70,8 +131,8 @@ function loss_adjoint() ##L2 norm biases the newer times unfairly c = 0.0 a = predict_adjoint() d = 0.0 - for i=1:59 - c += (a[i][1]/GDP[i]-1)^2 ## L2 of relative error + for i in 1:59 + c += (a[i][1] / GDP[i] - 1)^2 ## L2 of relative error end c + 3 * d end @@ -81,10 +142,10 @@ opt = ADAM(0.01) peek = function () #callback function to observe training #reduces training speed by a lot - println("Loss: ",loss_adjoint()) + println("Loss: ", loss_adjoint()) end peek() -Flux.train!(loss_adjoint, Flux.params(p,u0), data, opt, cb=peek) +Flux.train!(loss_adjoint, Flux.params(p, u0), data, opt, cb = peek) peek() @test loss_adjoint() < 0.2 diff --git a/test/gpu/diffeqflux_standard_gpu.jl b/test/gpu/diffeqflux_standard_gpu.jl index db5797d97..0ac297d41 100644 --- a/test/gpu/diffeqflux_standard_gpu.jl +++ b/test/gpu/diffeqflux_standard_gpu.jl @@ -1,40 +1,39 @@ -using SciMLSensitivity, OrdinaryDiffEq, Flux, DiffEqFlux, CUDA, Zygote -CUDA.allowscalar(false) # Makes sure no slow operations are occuring - -# Generate Data -u0 = Float32[2.0; 0.0] -datasize = 30 -tspan = (0.0f0, 1.5f0) -tsteps = range(tspan[1], tspan[2], length = datasize) -function trueODEfunc(du, u, p, t) - true_A = Float32[-0.1 2.0; -2.0 -0.1] - du .= ((u.^3)'true_A)' -end -prob_trueode = ODEProblem(trueODEfunc, u0, tspan) -# Make the data into a GPU-based array if the user has a GPU -ode_data = gpu(solve(prob_trueode, Tsit5(), saveat = tsteps)) - - -dudt2 = Chain(x -> x.^3, - Dense(2, 50, tanh), - Dense(50, 2)) |> gpu -u0 = Float32[2.0; 0.0] |> gpu - -_p,re = Flux.destructure(dudt2) -p = gpu(_p) - -prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps) - -function predict_neuralode(p) - gpu(prob_neuralode(u0,p)) -end -function loss_neuralode(p) - pred = predict_neuralode(p) - loss = sum(abs2, ode_data .- pred) - return loss -end -# Callback function to observe training -list_plots = [] -iter = 0 - -Zygote.gradient(loss_neuralode, p) \ No newline at end of file +using SciMLSensitivity, OrdinaryDiffEq, Flux, DiffEqFlux, CUDA, Zygote +CUDA.allowscalar(false) # Makes sure no slow operations are occuring + +# Generate Data +u0 = Float32[2.0; 0.0] +datasize = 30 +tspan = (0.0f0, 1.5f0) +tsteps = range(tspan[1], tspan[2], length = datasize) +function trueODEfunc(du, u, p, t) + true_A = Float32[-0.1 2.0; -2.0 -0.1] + du .= ((u .^ 3)'true_A)' +end +prob_trueode = ODEProblem(trueODEfunc, u0, tspan) +# Make the data into a GPU-based array if the user has a GPU +ode_data = gpu(solve(prob_trueode, Tsit5(), saveat = tsteps)) + +dudt2 = Chain(x -> x .^ 3, + Dense(2, 50, tanh), + Dense(50, 2)) |> gpu +u0 = Float32[2.0; 0.0] |> gpu + +_p, re = Flux.destructure(dudt2) +p = gpu(_p) + +prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps) + +function predict_neuralode(p) + gpu(prob_neuralode(u0, p)) +end +function loss_neuralode(p) + pred = predict_neuralode(p) + loss = sum(abs2, ode_data .- pred) + return loss +end +# Callback function to observe training +list_plots = [] +iter = 0 + +Zygote.gradient(loss_neuralode, p) diff --git a/test/gpu/mixed_gpu_cpu_adjoint.jl b/test/gpu/mixed_gpu_cpu_adjoint.jl index 7f08b9544..ceec5f877 100644 --- a/test/gpu/mixed_gpu_cpu_adjoint.jl +++ b/test/gpu/mixed_gpu_cpu_adjoint.jl @@ -5,10 +5,10 @@ CUDA.allowscalar(false) H = CuArray(rand(Float32, 2, 2)) ann = Chain(Dense(1, 4, tanh)) -p,re = Flux.destructure(ann) +p, re = Flux.destructure(ann) function func(x, p, t) - (re(p)([t])[1]*H)*x + (re(p)([t])[1] * H) * x end x0 = CuArray(rand(Float32, 2)) @@ -17,19 +17,19 @@ x1 = CuArray(rand(Float32, 2)) prob = ODEProblem(func, x0, (0.0f0, 1.0f0)) function evolve(p) - solve(prob, Tsit5(), p=p, save_start=false, - save_everystep=false, abstol=1e-4, reltol=1e-4, - sensealg=QuadratureAdjoint(autojacvec=ZygoteVJP())).u[1] + solve(prob, Tsit5(), p = p, save_start = false, + save_everystep = false, abstol = 1e-4, reltol = 1e-4, + sensealg = QuadratureAdjoint(autojacvec = ZygoteVJP())).u[1] end function cost(p) x = evolve(p) - c = sum(abs,x - x1) + c = sum(abs, x - x1) #println(c) c end -grad = Zygote.gradient(cost,p)[1] +grad = Zygote.gradient(cost, p)[1] @test !iszero(grad[1]) @test iszero(grad[2:4]) @test !iszero(grad[5]) @@ -42,25 +42,25 @@ grad = Zygote.gradient(cost,p)[1] rng = MersenneTwister(1234) m = 32 n = 16 -Z = randn(rng, Float32, (n,m)) |> gpu -𝒯 = 2f0 -Δτ = 1f-1 -ca_init = [zeros(1) ; ones(m)] |> gpu +Z = randn(rng, Float32, (n, m)) |> gpu +𝒯 = 2.0f0 +Δτ = 1.0f-1 +ca_init = [zeros(1); ones(m)] |> gpu function f(ca, Z, t) - a = ca[2:end] - a_unit = a / sum(a) - w_unit = Z*a_unit - Ka_unit = Z'*w_unit - z_unit = dot(abs.(Ka_unit), a_unit) - aKa_over_z = a .* Ka_unit / z_unit - [sum(aKa_over_z) / m; -abs.(aKa_over_z)] |> gpu + a = ca[2:end] + a_unit = a / sum(a) + w_unit = Z * a_unit + Ka_unit = Z' * w_unit + z_unit = dot(abs.(Ka_unit), a_unit) + aKa_over_z = a .* Ka_unit / z_unit + [sum(aKa_over_z) / m; -abs.(aKa_over_z)] |> gpu end function c(Z) - prob = ODEProblem(f, ca_init, (0f0,𝒯), Z, saveat=Δτ) - sol = solve(prob, Tsit5(), sensealg=BacksolveAdjoint(), saveat=Δτ) - sum(last(sol.u)) + prob = ODEProblem(f, ca_init, (0.0f0, 𝒯), Z, saveat = Δτ) + sol = solve(prob, Tsit5(), sensealg = BacksolveAdjoint(), saveat = Δτ) + sum(last(sol.u)) end println("forward:", c(Z)) diff --git a/test/hasbranching.jl b/test/hasbranching.jl index 025673de3..ec65955ee 100644 --- a/test/hasbranching.jl +++ b/test/hasbranching.jl @@ -1,9 +1,9 @@ using SciMLSensitivity, Test @test SciMLSensitivity.hasbranching(1, 2) do x, y - (x < 0 ? -x : x) + exp(y) + (x < 0 ? -x : x) + exp(y) end @test !SciMLSensitivity.hasbranching(1, 2) do x, y - ifelse(x < 0, -x, x) + exp(y) + ifelse(x < 0, -x, x) + exp(y) end diff --git a/test/hybrid_de.jl b/test/hybrid_de.jl index 51e30127d..be9e07754 100644 --- a/test/hybrid_de.jl +++ b/test/hybrid_de.jl @@ -1,56 +1,57 @@ using Flux, SciMLSensitivity, DiffEqCallbacks, OrdinaryDiffEq, Test # , Plots -u0 = Float32[2.; 0.] +u0 = Float32[2.0; 0.0] datasize = 100 -tspan = (0.0f0,10.5f0) -dosetimes = [1.0,2.0,4.0,8.0] +tspan = (0.0f0, 10.5f0) +dosetimes = [1.0, 2.0, 4.0, 8.0] function affect!(integrator) - integrator.u = integrator.u.+1 + integrator.u = integrator.u .+ 1 end -cb_ = PresetTimeCallback(dosetimes,affect!,save_positions=(false,false)) -function trueODEfunc(du,u,p,t) +cb_ = PresetTimeCallback(dosetimes, affect!, save_positions = (false, false)) +function trueODEfunc(du, u, p, t) du .= -u end -t = range(tspan[1],tspan[2],length=datasize) +t = range(tspan[1], tspan[2], length = datasize) -prob = ODEProblem(trueODEfunc,u0,tspan) -ode_data = Array(solve(prob,Tsit5(),callback=cb_,saveat=t)) -dudt2 = Chain(Dense(2,50,tanh), - Dense(50,2)) -p,re = Flux.destructure(dudt2) # use this p as the initial condition! +prob = ODEProblem(trueODEfunc, u0, tspan) +ode_data = Array(solve(prob, Tsit5(), callback = cb_, saveat = t)) +dudt2 = Chain(Dense(2, 50, tanh), + Dense(50, 2)) +p, re = Flux.destructure(dudt2) # use this p as the initial condition! -function dudt(du,u,p,t) +function dudt(du, u, p, t) du[1:2] .= -u[1:2] du[3:end] .= re(p)(u[1:2]) #re(p)(u[3:end]) end -z0 = Float32[u0;u0] -prob = ODEProblem(dudt,z0,tspan) +z0 = Float32[u0; u0] +prob = ODEProblem(dudt, z0, tspan) affect!(integrator) = integrator.u[1:2] .= integrator.u[3:end] -cb = PresetTimeCallback(dosetimes,affect!,save_positions=(false,false)) +cb = PresetTimeCallback(dosetimes, affect!, save_positions = (false, false)) function predict_n_ode() - _prob = remake(prob,p=p) - Array(solve(_prob,Tsit5(),u0=z0,p=p,callback=cb,saveat=t,sensealg=ReverseDiffAdjoint()))[1:2,:] + _prob = remake(prob, p = p) + Array(solve(_prob, Tsit5(), u0 = z0, p = p, callback = cb, saveat = t, + sensealg = ReverseDiffAdjoint()))[1:2, :] # Array(solve(prob,Tsit5(),u0=z0,p=p,saveat=t))[1:2,:] end function loss_n_ode() pred = predict_n_ode() - loss = sum(abs2,ode_data .- pred) + loss = sum(abs2, ode_data .- pred) loss end loss_n_ode() # n_ode.p stores the initial parameters of the neural ODE -cba = function (;doplot=false) #callback function to observe training - pred = predict_n_ode() - display(sum(abs2,ode_data .- pred)) - # plot current prediction against data - # pl = scatter(t,ode_data[1,:],label="data") - # scatter!(pl,t,pred[1,:],label="prediction") - # display(plot(pl)) - return false +cba = function (; doplot = false) #callback function to observe training + pred = predict_n_ode() + display(sum(abs2, ode_data .- pred)) + # plot current prediction against data + # pl = scatter(t,ode_data[1,:],label="data") + # scatter!(pl,t,pred[1,:],label="prediction") + # display(plot(pl)) + return false end cba() diff --git a/test/layers.jl b/test/layers.jl index 08b3084f4..fdc5a0f14 100644 --- a/test/layers.jl +++ b/test/layers.jl @@ -1,22 +1,23 @@ using SciMLSensitivity, Flux, Zygote, OrdinaryDiffEq, Test # , Plots -function lotka_volterra(du,u,p,t) - x, y = u - α, β, δ, γ = p - du[1] = dx = (α - β*y)x - du[2] = dy = (δ*x - γ)y +function lotka_volterra(du, u, p, t) + x, y = u + α, β, δ, γ = p + du[1] = dx = (α - β * y)x + du[2] = dy = (δ * x - γ)y end p = [2.2, 1.0, 2.0, 0.4] -u0 = [1.0,1.0] -prob = ODEProblem(lotka_volterra,u0,(0.0,10.0),p) +u0 = [1.0, 1.0] +prob = ODEProblem(lotka_volterra, u0, (0.0, 10.0), p) # Reverse-mode function predict_rd(p) - Array(solve(prob,Tsit5(),p=p,saveat=0.1,reltol=1e-4,sensealg=TrackerAdjoint())) + Array(solve(prob, Tsit5(), p = p, saveat = 0.1, reltol = 1e-4, + sensealg = TrackerAdjoint())) end -loss_rd(p) = sum(abs2,x-1 for x in predict_rd(p)) -loss_rd() = sum(abs2,x-1 for x in predict_rd(p)) +loss_rd(p) = sum(abs2, x - 1 for x in predict_rd(p)) +loss_rd() = sum(abs2, x - 1 for x in predict_rd(p)) loss_rd() grads = Zygote.gradient(loss_rd, p) @@ -24,8 +25,8 @@ grads = Zygote.gradient(loss_rd, p) opt = ADAM(0.1) cb = function () - display(loss_rd()) - # display(plot(solve(remake(prob,p=p),Tsit5(),saveat=0.1),ylim=(0,6))) + display(loss_rd()) + # display(plot(solve(remake(prob,p=p),Tsit5(),saveat=0.1),ylim=(0,6))) end # Display the ODE with the current parameter values. @@ -38,9 +39,10 @@ loss2 = loss_rd() p = [2.2, 1.0, 2.0, 0.4] function predict_fd() - vec(Array(solve(prob,Tsit5(),p=p,saveat=0.0:0.1:1.0,reltol=1e-4,sensealg=ForwardDiffSensitivity()))) + vec(Array(solve(prob, Tsit5(), p = p, saveat = 0.0:0.1:1.0, reltol = 1e-4, + sensealg = ForwardDiffSensitivity()))) end -loss_fd() = sum(abs2,x-1 for x in predict_fd()) +loss_fd() = sum(abs2, x - 1 for x in predict_fd()) loss_fd() ps = Flux.params(p) @@ -50,8 +52,8 @@ grads = Zygote.gradient(loss_fd, ps) data = Iterators.repeated((), 100) opt = ADAM(0.1) cb = function () - display(loss_fd()) - # display(plot(solve(remake(prob,p=p),Tsit5(),saveat=0.1),ylim=(0,6))) + display(loss_fd()) + # display(plot(solve(remake(prob,p=p),Tsit5(),saveat=0.1),ylim=(0,6))) end # Display the ODE with the current parameter values. @@ -64,9 +66,9 @@ loss2 = loss_fd() p = [2.2, 1.0, 2.0, 0.4] ps = Flux.params(p) function predict_adjoint() - solve(remake(prob,p=p),Tsit5(),saveat=0.1,reltol=1e-4) + solve(remake(prob, p = p), Tsit5(), saveat = 0.1, reltol = 1e-4) end -loss_reduction(sol) = sum(abs2,x-1 for x in vec(sol)) +loss_reduction(sol) = sum(abs2, x - 1 for x in vec(sol)) loss_adjoint() = loss_reduction(predict_adjoint()) loss_adjoint() @@ -76,8 +78,8 @@ grads = Zygote.gradient(loss_adjoint, ps) data = Iterators.repeated((), 100) opt = ADAM(0.1) cb = function () - display(loss_adjoint()) - # display(plot(solve(remake(prob,p=p),Tsit5(),saveat=0.1),ylim=(0,6))) + display(loss_adjoint()) + # display(plot(solve(remake(prob,p=p),Tsit5(),saveat=0.1),ylim=(0,6))) end # Display the ODE with the current parameter values. diff --git a/test/layers_dde.jl b/test/layers_dde.jl index 02095f736..28331d3e5 100644 --- a/test/layers_dde.jl +++ b/test/layers_dde.jl @@ -1,27 +1,29 @@ using SciMLSensitivity, Flux, Zygote, DelayDiffEq, Test ## Setup DDE to optimize -function delay_lotka_volterra(du,u,h,p,t) - x, y = u - α, β, δ, γ = p - du[1] = dx = (α - β*y)*h(p,t-0.1)[1] - du[2] = dy = (δ*x - γ)*y +function delay_lotka_volterra(du, u, h, p, t) + x, y = u + α, β, δ, γ = p + du[1] = dx = (α - β * y) * h(p, t - 0.1)[1] + du[2] = dy = (δ * x - γ) * y end -h(p,t) = ones(eltype(p),2) -prob = DDEProblem(delay_lotka_volterra,[1.0,1.0],h,(0.0,10.0),constant_lags=[0.1]) +h(p, t) = ones(eltype(p), 2) +prob = DDEProblem(delay_lotka_volterra, [1.0, 1.0], h, (0.0, 10.0), constant_lags = [0.1]) p = [2.2, 1.0, 2.0, 0.4] function predict_fd_dde(p) - solve(prob,MethodOfSteps(Tsit5()),p=p,saveat=0.0:0.1:10.0,reltol=1e-4,sensealg=ForwardDiffSensitivity())[1,:] + solve(prob, MethodOfSteps(Tsit5()), p = p, saveat = 0.0:0.1:10.0, reltol = 1e-4, + sensealg = ForwardDiffSensitivity())[1, :] end -loss_fd_dde(p) = sum(abs2,x-1 for x in predict_fd_dde(p)) +loss_fd_dde(p) = sum(abs2, x - 1 for x in predict_fd_dde(p)) loss_fd_dde(p) -@test !iszero(Zygote.gradient(loss_fd_dde,p)[1]) +@test !iszero(Zygote.gradient(loss_fd_dde, p)[1]) function predict_rd_dde(p) - solve(prob,MethodOfSteps(Tsit5()),p=p,saveat=0.1,reltol=1e-4,sensealg=TrackerAdjoint())[1,:] + solve(prob, MethodOfSteps(Tsit5()), p = p, saveat = 0.1, reltol = 1e-4, + sensealg = TrackerAdjoint())[1, :] end -loss_rd_dde(p) = sum(abs2,x-1 for x in predict_rd_dde(p)) +loss_rd_dde(p) = sum(abs2, x - 1 for x in predict_rd_dde(p)) loss_rd_dde(p) -@test !iszero(Zygote.gradient(loss_rd_dde,p)[1]) +@test !iszero(Zygote.gradient(loss_rd_dde, p)[1]) -@test Zygote.gradient(loss_fd_dde,p)[1] ≈ Zygote.gradient(loss_rd_dde,p)[1] rtol=1e-2 +@test Zygote.gradient(loss_fd_dde, p)[1]≈Zygote.gradient(loss_rd_dde, p)[1] rtol=1e-2 diff --git a/test/layers_sde.jl b/test/layers_sde.jl index 55a43374e..36db20d70 100644 --- a/test/layers_sde.jl +++ b/test/layers_sde.jl @@ -1,53 +1,55 @@ using SciMLSensitivity, Flux, Zygote, StochasticDiffEq, Test -function lotka_volterra(du,u,p,t) - x, y = u - α, β, δ, γ = p - du[1] = dx = α*x - β*x*y - du[2] = dy = -δ*y + γ*x*y -end -function lotka_volterra(u,p,t) - x, y = u - α, β, δ, γ = p - dx = α*x - β*x*y - dy = -δ*y + γ*x*y - [dx,dy] -end -function lotka_volterra_noise(du,u,p,t) - du[1] = 0.01u[1] - du[2] = 0.01u[2] -end -function lotka_volterra_noise(u,p,t) - [0.01u[1],0.01u[2]] -end -prob = SDEProblem(lotka_volterra,lotka_volterra_noise,[1.0,1.0],(0.0,10.0)) +function lotka_volterra(du, u, p, t) + x, y = u + α, β, δ, γ = p + du[1] = dx = α * x - β * x * y + du[2] = dy = -δ * y + γ * x * y +end +function lotka_volterra(u, p, t) + x, y = u + α, β, δ, γ = p + dx = α * x - β * x * y + dy = -δ * y + γ * x * y + [dx, dy] +end +function lotka_volterra_noise(du, u, p, t) + du[1] = 0.01u[1] + du[2] = 0.01u[2] +end +function lotka_volterra_noise(u, p, t) + [0.01u[1], 0.01u[2]] +end +prob = SDEProblem(lotka_volterra, lotka_volterra_noise, [1.0, 1.0], (0.0, 10.0)) p = [2.2, 1.0, 2.0, 0.4] function predict_fd_sde(p) - solve(prob,SOSRI(),p=p,saveat=0.0:0.1:0.5,sensealg=ForwardDiffSensitivity())[1,:] + solve(prob, SOSRI(), p = p, saveat = 0.0:0.1:0.5, sensealg = ForwardDiffSensitivity())[1, + :] end -loss_fd_sde(p) = sum(abs2,x-1 for x in predict_fd_sde(p)) +loss_fd_sde(p) = sum(abs2, x - 1 for x in predict_fd_sde(p)) loss_fd_sde(p) -prob = SDEProblem{false}(lotka_volterra,lotka_volterra_noise,[1.0,1.0],(0.0,10.0)) +prob = SDEProblem{false}(lotka_volterra, lotka_volterra_noise, [1.0, 1.0], (0.0, 10.0)) p = [2.2, 1.0, 2.0, 0.4] function predict_fd_sde(p) - solve(prob,SOSRI(),p=p,saveat=0.0:0.1:0.5,sensealg=ForwardDiffSensitivity())[1,:] + solve(prob, SOSRI(), p = p, saveat = 0.0:0.1:0.5, sensealg = ForwardDiffSensitivity())[1, + :] end -loss_fd_sde(p) = sum(abs2,x-1 for x in predict_fd_sde(p)) +loss_fd_sde(p) = sum(abs2, x - 1 for x in predict_fd_sde(p)) loss_fd_sde(p) -@test !iszero(Zygote.gradient(loss_fd_sde,p)[1]) +@test !iszero(Zygote.gradient(loss_fd_sde, p)[1]) -prob = SDEProblem(lotka_volterra,lotka_volterra_noise,[1.0,1.0],(0.0,0.5)) +prob = SDEProblem(lotka_volterra, lotka_volterra_noise, [1.0, 1.0], (0.0, 0.5)) function predict_rd_sde(p) - solve(prob,SOSRI(),p=p,saveat=0.0:0.1:0.5,sensealg=TrackerAdjoint())[1,:] + solve(prob, SOSRI(), p = p, saveat = 0.0:0.1:0.5, sensealg = TrackerAdjoint())[1, :] end -loss_rd_sde(p) = sum(abs2,x-1 for x in predict_rd_sde(p)) -@test !iszero(Zygote.gradient(loss_rd_sde,p)[1]) +loss_rd_sde(p) = sum(abs2, x - 1 for x in predict_rd_sde(p)) +@test !iszero(Zygote.gradient(loss_rd_sde, p)[1]) -prob = SDEProblem{false}(lotka_volterra,lotka_volterra_noise,[1.0,1.0],(0.0,0.5)) +prob = SDEProblem{false}(lotka_volterra, lotka_volterra_noise, [1.0, 1.0], (0.0, 0.5)) function predict_rd_sde(p) - solve(prob,SOSRI(),p=p,saveat=0.0:0.1:0.5,sensealg=TrackerAdjoint())[1,:] + solve(prob, SOSRI(), p = p, saveat = 0.0:0.1:0.5, sensealg = TrackerAdjoint())[1, :] end -loss_rd_sde(p) = sum(abs2,x-1 for x in predict_rd_sde(p)) -@test !iszero(Zygote.gradient(loss_rd_sde,p)[1]) +loss_rd_sde(p) = sum(abs2, x - 1 for x in predict_rd_sde(p)) +@test !iszero(Zygote.gradient(loss_rd_sde, p)[1]) diff --git a/test/literal_adjoint.jl b/test/literal_adjoint.jl index 4d39b4b1f..a6518ab00 100644 --- a/test/literal_adjoint.jl +++ b/test/literal_adjoint.jl @@ -1,23 +1,23 @@ using SciMLSensitivity, OrdinaryDiffEq, Zygote, Test function lv!(du, u, p, t) - x,y = u + x, y = u a, b, c, d = p - du[1] = a*x - b*x*y - du[2] = -c*y + d*x*y + du[1] = a * x - b * x * y + du[2] = -c * y + d * x * y end -function test(u0,p) - tspan = [0.,1.] +function test(u0, p) + tspan = [0.0, 1.0] prob = ODEProblem(lv!, u0, tspan, p) - sol = solve(prob,Tsit5()) + sol = solve(prob, Tsit5()) return sol.u[end][1] end -function test2(u0,p) - tspan = [0.,1.] +function test2(u0, p) + tspan = [0.0, 1.0] prob = ODEProblem(lv!, u0, tspan, p) - sol = solve(prob,Tsit5()) - return Array(sol)[1,end] + sol = solve(prob, Tsit5()) + return Array(sol)[1, end] end -u0 = [1.,1.] -p = [1.,1.,1.,1.] -@test Zygote.gradient(test,u0,p) == Zygote.gradient(test2,u0,p) +u0 = [1.0, 1.0] +p = [1.0, 1.0, 1.0, 1.0] +@test Zygote.gradient(test, u0, p) == Zygote.gradient(test2, u0, p) diff --git a/test/mixed_costs.jl b/test/mixed_costs.jl index 9977ce36a..899531362 100644 --- a/test/mixed_costs.jl +++ b/test/mixed_costs.jl @@ -9,20 +9,21 @@ reltol = 1e-14 savingtimes = collect(1.0:9.0) function fiip(du, u, p, t) - du[1] = dx = p[1] * u[1] - p[2] * u[1] * u[2] - du[2] = dy = -p[3] * u[2] + p[4] * u[1] * u[2] + du[1] = dx = p[1] * u[1] - p[2] * u[1] * u[2] + du[2] = dy = -p[3] * u[2] + p[4] * u[1] * u[2] end ## Continuous cost functionals function continuous_cost_forward(input) - u0 = input[1:2] - p = input[3:end] + u0 = input[1:2] + p = input[3:end] - prob = ODEProblem(fiip, u0, (0.0, 10.0), p) - sol = solve(prob, Tsit5(), abstol=abstol, reltol=reltol) - cost, err = quadgk((t) -> sol(t)[1]^2 + p[1], prob.tspan..., atol=abstol, rtol=reltol) - cost + prob = ODEProblem(fiip, u0, (0.0, 10.0), p) + sol = solve(prob, Tsit5(), abstol = abstol, reltol = reltol) + cost, err = quadgk((t) -> sol(t)[1]^2 + p[1], prob.tspan..., atol = abstol, + rtol = reltol) + cost end p = [1.5, 1.0, 3.0, 1.0] u0 = [1.0; 1.0] @@ -33,222 +34,335 @@ dFiniteDiff = FiniteDiff.finite_difference_gradient(continuous_cost_forward, inp @test dForwardDiff ≈ dFiniteDiff prob = ODEProblem(fiip, u0, (0.0, 10.0), p) -sol = solve(prob, Tsit5(), reltol=reltol, abstol=abstol) +sol = solve(prob, Tsit5(), reltol = reltol, abstol = abstol) g(u, p, t) = u[1]^2 + p[1] function dgdu(out, u, p, t) - out[1] = 2u[1] - out[2] = 0.0 + out[1] = 2u[1] + out[2] = 0.0 end function dgdp(out, u, p, t) - out[1] = 1.0 - out[2] = 0.0 - out[3] = 0.0 - out[4] = 0.0 + out[1] = 1.0 + out[2] = 0.0 + out[3] = 0.0 + out[4] = 0.0 end # BacksolveAdjoint, all vjps -du0, dp = adjoint_sensitivities(sol, Tsit5(), dg_continuous=(dgdu, dgdp), g=g, sensealg=BacksolveAdjoint(autojacvec=EnzymeVJP()), abstol=abstol, reltol=reltol) +du0, dp = adjoint_sensitivities(sol, Tsit5(), dg_continuous = (dgdu, dgdp), g = g, + sensealg = BacksolveAdjoint(autojacvec = EnzymeVJP()), + abstol = abstol, reltol = reltol) @test du0 ≈ dForwardDiff[1:2] @test dp' ≈ dForwardDiff[3:6] -du0, dp = adjoint_sensitivities(sol, Tsit5(), dg_continuous=(dgdu, dgdp), g=g, sensealg=BacksolveAdjoint(autojacvec=ReverseDiffVJP()), abstol=abstol, reltol=reltol) +du0, dp = adjoint_sensitivities(sol, Tsit5(), dg_continuous = (dgdu, dgdp), g = g, + sensealg = BacksolveAdjoint(autojacvec = ReverseDiffVJP()), + abstol = abstol, reltol = reltol) @test du0 ≈ dForwardDiff[1:2] @test dp' ≈ dForwardDiff[3:6] -du0, dp = adjoint_sensitivities(sol, Tsit5(), dg_continuous=(dgdu, dgdp), g=g, sensealg=BacksolveAdjoint(autojacvec=TrackerVJP()), abstol=abstol, reltol=reltol) +du0, dp = adjoint_sensitivities(sol, Tsit5(), dg_continuous = (dgdu, dgdp), g = g, + sensealg = BacksolveAdjoint(autojacvec = TrackerVJP()), + abstol = abstol, reltol = reltol) @test du0 ≈ dForwardDiff[1:2] @test dp' ≈ dForwardDiff[3:6] -du0, dp = adjoint_sensitivities(sol, Tsit5(), dg_continuous=(dgdu, dgdp), g=g, sensealg=BacksolveAdjoint(autojacvec=false), abstol=abstol, reltol=reltol) +du0, dp = adjoint_sensitivities(sol, Tsit5(), dg_continuous = (dgdu, dgdp), g = g, + sensealg = BacksolveAdjoint(autojacvec = false), + abstol = abstol, reltol = reltol) @test du0 ≈ dForwardDiff[1:2] @test dp' ≈ dForwardDiff[3:6] -du0, dp = adjoint_sensitivities(sol, Tsit5(), dg_continuous=(dgdu, dgdp), g=g, sensealg=BacksolveAdjoint(autojacvec=ZygoteVJP()), abstol=abstol, reltol=reltol) +du0, dp = adjoint_sensitivities(sol, Tsit5(), dg_continuous = (dgdu, dgdp), g = g, + sensealg = BacksolveAdjoint(autojacvec = ZygoteVJP()), + abstol = abstol, reltol = reltol) @test du0 ≈ dForwardDiff[1:2] @test dp' ≈ dForwardDiff[3:6] # InterpolatingAdjoint, all vjps -du0, dp = adjoint_sensitivities(sol, Tsit5(), dg_continuous=(dgdu, dgdp), g=g, sensealg=InterpolatingAdjoint(autojacvec=EnzymeVJP()), abstol=abstol, reltol=reltol) +du0, dp = adjoint_sensitivities(sol, Tsit5(), dg_continuous = (dgdu, dgdp), g = g, + sensealg = InterpolatingAdjoint(autojacvec = EnzymeVJP()), + abstol = abstol, reltol = reltol) @test du0 ≈ dForwardDiff[1:2] @test dp' ≈ dForwardDiff[3:6] -du0, dp = adjoint_sensitivities(sol, Tsit5(), dg_continuous=(dgdu, dgdp), g=g, sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP()), abstol=abstol, reltol=reltol) +du0, dp = adjoint_sensitivities(sol, Tsit5(), dg_continuous = (dgdu, dgdp), g = g, + sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP()), + abstol = abstol, reltol = reltol) @test du0 ≈ dForwardDiff[1:2] @test dp' ≈ dForwardDiff[3:6] -du0, dp = adjoint_sensitivities(sol, Tsit5(), dg_continuous=(dgdu, dgdp), g=g, sensealg=InterpolatingAdjoint(autojacvec=TrackerVJP()), abstol=abstol, reltol=reltol) +du0, dp = adjoint_sensitivities(sol, Tsit5(), dg_continuous = (dgdu, dgdp), g = g, + sensealg = InterpolatingAdjoint(autojacvec = TrackerVJP()), + abstol = abstol, reltol = reltol) @test du0 ≈ dForwardDiff[1:2] @test dp' ≈ dForwardDiff[3:6] -du0, dp = adjoint_sensitivities(sol, Tsit5(), dg_continuous=(dgdu, dgdp), g=g, sensealg=InterpolatingAdjoint(autojacvec=false), abstol=abstol, reltol=reltol) +du0, dp = adjoint_sensitivities(sol, Tsit5(), dg_continuous = (dgdu, dgdp), g = g, + sensealg = InterpolatingAdjoint(autojacvec = false), + abstol = abstol, reltol = reltol) @test du0 ≈ dForwardDiff[1:2] @test dp' ≈ dForwardDiff[3:6] -du0, dp = adjoint_sensitivities(sol, Tsit5(), dg_continuous=(dgdu, dgdp), g=g, sensealg=InterpolatingAdjoint(autojacvec=ZygoteVJP()), abstol=abstol, reltol=reltol) +du0, dp = adjoint_sensitivities(sol, Tsit5(), dg_continuous = (dgdu, dgdp), g = g, + sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP()), + abstol = abstol, reltol = reltol) @test du0 ≈ dForwardDiff[1:2] @test dp' ≈ dForwardDiff[3:6] # QuadratureAdjoint, all vjps -du0, dp = adjoint_sensitivities(sol, Tsit5(), dg_continuous=(dgdu, dgdp), g=g, sensealg=QuadratureAdjoint(autojacvec=EnzymeVJP(), abstol=abstol, reltol=reltol), abstol=abstol, reltol=reltol) +du0, dp = adjoint_sensitivities(sol, Tsit5(), dg_continuous = (dgdu, dgdp), g = g, + sensealg = QuadratureAdjoint(autojacvec = EnzymeVJP(), + abstol = abstol, + reltol = reltol), + abstol = abstol, reltol = reltol) @test du0 ≈ dForwardDiff[1:2] @test dp' ≈ dForwardDiff[3:6] -du0, dp = adjoint_sensitivities(sol, Tsit5(), dg_continuous=(dgdu, dgdp), g=g, sensealg=QuadratureAdjoint(autojacvec=ReverseDiffVJP(), abstol=abstol, reltol=reltol), abstol=abstol, reltol=reltol) +du0, dp = adjoint_sensitivities(sol, Tsit5(), dg_continuous = (dgdu, dgdp), g = g, + sensealg = QuadratureAdjoint(autojacvec = ReverseDiffVJP(), + abstol = abstol, + reltol = reltol), + abstol = abstol, reltol = reltol) @test du0 ≈ dForwardDiff[1:2] @test dp' ≈ dForwardDiff[3:6] -du0, dp = adjoint_sensitivities(sol, Tsit5(), dg_continuous=(dgdu, dgdp), g=g, sensealg=QuadratureAdjoint(autojacvec=false, abstol=abstol, reltol=reltol), abstol=abstol, reltol=reltol) +du0, dp = adjoint_sensitivities(sol, Tsit5(), dg_continuous = (dgdu, dgdp), g = g, + sensealg = QuadratureAdjoint(autojacvec = false, + abstol = abstol, + reltol = reltol), + abstol = abstol, reltol = reltol) @test du0 ≈ dForwardDiff[1:2] @test dp' ≈ dForwardDiff[3:6] ## ## Discrete costs -function discrete_cost_forward(input, sensealg=nothing) - u0 = input[1:2] - p = input[3:end] - - prob = ODEProblem(fiip, u0, (0.0, 10.0), p) - sol = Array(solve(prob, Tsit5(), abstol=abstol, reltol=reltol, saveat=savingtimes, sensealg=sensealg, save_start=false, save_end=false)) - cost = zero(eltype(p)) - for u in eachcol(sol) - cost += u[1]^2 #+ p[1] - end - cost +function discrete_cost_forward(input, sensealg = nothing) + u0 = input[1:2] + p = input[3:end] + + prob = ODEProblem(fiip, u0, (0.0, 10.0), p) + sol = Array(solve(prob, Tsit5(), abstol = abstol, reltol = reltol, saveat = savingtimes, + sensealg = sensealg, save_start = false, save_end = false)) + cost = zero(eltype(p)) + for u in eachcol(sol) + cost += u[1]^2 #+ p[1] + end + cost end dForwardDiff = ForwardDiff.gradient(discrete_cost_forward, input) dFiniteDiff = FiniteDiff.finite_difference_gradient(discrete_cost_forward, input) @test dForwardDiff ≈ dFiniteDiff function dgdu(out, u, p, t, i) - out[1] = 2u[1] - out[2] = 0.0 + out[1] = 2u[1] + out[2] = 0.0 end function dgdp(out, u, p, t, i) - out[1] = 0.0 - out[2] = 0.0 - out[3] = 0.0 - out[4] = 0.0 + out[1] = 0.0 + out[2] = 0.0 + out[3] = 0.0 + out[4] = 0.0 end # BacksolveAdjoint, all vjps -du0, dp = adjoint_sensitivities(sol, Tsit5(), t=savingtimes, dg_discrete=dgdu, sensealg=BacksolveAdjoint(autojacvec=EnzymeVJP()), abstol=abstol, reltol=reltol) +du0, dp = adjoint_sensitivities(sol, Tsit5(), t = savingtimes, dg_discrete = dgdu, + sensealg = BacksolveAdjoint(autojacvec = EnzymeVJP()), + abstol = abstol, reltol = reltol) @test du0 ≈ dForwardDiff[1:2] @test dp' ≈ dForwardDiff[3:6] -du0, dp = adjoint_sensitivities(sol, Tsit5(), t=savingtimes, dg_discrete=dgdu, sensealg=BacksolveAdjoint(autojacvec=ReverseDiffVJP()), abstol=abstol, reltol=reltol) +du0, dp = adjoint_sensitivities(sol, Tsit5(), t = savingtimes, dg_discrete = dgdu, + sensealg = BacksolveAdjoint(autojacvec = ReverseDiffVJP()), + abstol = abstol, reltol = reltol) @test du0 ≈ dForwardDiff[1:2] @test dp' ≈ dForwardDiff[3:6] -du0, dp = adjoint_sensitivities(sol, Tsit5(), t=savingtimes, dg_discrete=dgdu, sensealg=BacksolveAdjoint(autojacvec=TrackerVJP()), abstol=abstol, reltol=reltol) +du0, dp = adjoint_sensitivities(sol, Tsit5(), t = savingtimes, dg_discrete = dgdu, + sensealg = BacksolveAdjoint(autojacvec = TrackerVJP()), + abstol = abstol, reltol = reltol) @test du0 ≈ dForwardDiff[1:2] @test dp' ≈ dForwardDiff[3:6] -du0, dp = adjoint_sensitivities(sol, Tsit5(), t=savingtimes, dg_discrete=dgdu, sensealg=BacksolveAdjoint(autojacvec=false), abstol=abstol, reltol=reltol) +du0, dp = adjoint_sensitivities(sol, Tsit5(), t = savingtimes, dg_discrete = dgdu, + sensealg = BacksolveAdjoint(autojacvec = false), + abstol = abstol, reltol = reltol) @test du0 ≈ dForwardDiff[1:2] @test dp' ≈ dForwardDiff[3:6] -du0, dp = adjoint_sensitivities(sol, Tsit5(), t=savingtimes, dg_discrete=dgdu, sensealg=BacksolveAdjoint(autojacvec=ZygoteVJP()), abstol=abstol, reltol=reltol) +du0, dp = adjoint_sensitivities(sol, Tsit5(), t = savingtimes, dg_discrete = dgdu, + sensealg = BacksolveAdjoint(autojacvec = ZygoteVJP()), + abstol = abstol, reltol = reltol) @test du0 ≈ dForwardDiff[1:2] @test dp' ≈ dForwardDiff[3:6] # InterpolatingAdjoint, all vjps -du0, dp = adjoint_sensitivities(sol, Tsit5(), t=savingtimes, dg_discrete=dgdu, sensealg=InterpolatingAdjoint(autojacvec=EnzymeVJP()), abstol=abstol, reltol=reltol) +du0, dp = adjoint_sensitivities(sol, Tsit5(), t = savingtimes, dg_discrete = dgdu, + sensealg = InterpolatingAdjoint(autojacvec = EnzymeVJP()), + abstol = abstol, reltol = reltol) @test du0 ≈ dForwardDiff[1:2] @test dp' ≈ dForwardDiff[3:6] -du0, dp = adjoint_sensitivities(sol, Tsit5(), t=savingtimes, dg_discrete=dgdu, sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP()), abstol=abstol, reltol=reltol) +du0, dp = adjoint_sensitivities(sol, Tsit5(), t = savingtimes, dg_discrete = dgdu, + sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP()), + abstol = abstol, reltol = reltol) @test du0 ≈ dForwardDiff[1:2] @test dp' ≈ dForwardDiff[3:6] -du0, dp = adjoint_sensitivities(sol, Tsit5(), t=savingtimes, dg_discrete=dgdu, sensealg=InterpolatingAdjoint(autojacvec=TrackerVJP()), abstol=abstol, reltol=reltol) +du0, dp = adjoint_sensitivities(sol, Tsit5(), t = savingtimes, dg_discrete = dgdu, + sensealg = InterpolatingAdjoint(autojacvec = TrackerVJP()), + abstol = abstol, reltol = reltol) @test du0 ≈ dForwardDiff[1:2] @test dp' ≈ dForwardDiff[3:6] -du0, dp = adjoint_sensitivities(sol, Tsit5(), t=savingtimes, dg_discrete=dgdu, sensealg=InterpolatingAdjoint(autojacvec=false), abstol=abstol, reltol=reltol) +du0, dp = adjoint_sensitivities(sol, Tsit5(), t = savingtimes, dg_discrete = dgdu, + sensealg = InterpolatingAdjoint(autojacvec = false), + abstol = abstol, reltol = reltol) @test du0 ≈ dForwardDiff[1:2] @test dp' ≈ dForwardDiff[3:6] -du0, dp = adjoint_sensitivities(sol, Tsit5(), t=savingtimes, dg_discrete=dgdu, sensealg=InterpolatingAdjoint(autojacvec=ZygoteVJP()), abstol=abstol, reltol=reltol) +du0, dp = adjoint_sensitivities(sol, Tsit5(), t = savingtimes, dg_discrete = dgdu, + sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP()), + abstol = abstol, reltol = reltol) @test du0 ≈ dForwardDiff[1:2] @test dp' ≈ dForwardDiff[3:6] # QuadratureAdjoint, all vjps -du0, dp = adjoint_sensitivities(sol, Tsit5(), t=savingtimes, dg_discrete=dgdu, sensealg=QuadratureAdjoint(autojacvec=EnzymeVJP(), abstol=abstol, reltol=reltol), abstol=abstol, reltol=reltol) +du0, dp = adjoint_sensitivities(sol, Tsit5(), t = savingtimes, dg_discrete = dgdu, + sensealg = QuadratureAdjoint(autojacvec = EnzymeVJP(), + abstol = abstol, + reltol = reltol), + abstol = abstol, reltol = reltol) @test du0 ≈ dForwardDiff[1:2] @test dp' ≈ dForwardDiff[3:6] -du0, dp = adjoint_sensitivities(sol, Tsit5(), t=savingtimes, dg_discrete=dgdu, sensealg=QuadratureAdjoint(autojacvec=ReverseDiffVJP(), abstol=abstol, reltol=reltol), abstol=abstol, reltol=reltol) +du0, dp = adjoint_sensitivities(sol, Tsit5(), t = savingtimes, dg_discrete = dgdu, + sensealg = QuadratureAdjoint(autojacvec = ReverseDiffVJP(), + abstol = abstol, + reltol = reltol), + abstol = abstol, reltol = reltol) @test du0 ≈ dForwardDiff[1:2] @test dp' ≈ dForwardDiff[3:6] -du0, dp = adjoint_sensitivities(sol, Tsit5(), t=savingtimes, dg_discrete=dgdu, sensealg=QuadratureAdjoint(autojacvec=false, abstol=abstol, reltol=reltol), abstol=abstol, reltol=reltol) +du0, dp = adjoint_sensitivities(sol, Tsit5(), t = savingtimes, dg_discrete = dgdu, + sensealg = QuadratureAdjoint(autojacvec = false, + abstol = abstol, + reltol = reltol), + abstol = abstol, reltol = reltol) @test du0 ≈ dForwardDiff[1:2] @test dp' ≈ dForwardDiff[3:6] # concrete solve interface dZygote = Zygote.gradient(input -> discrete_cost_forward(input, BacksolveAdjoint()), input)[1] @test dZygote ≈ dForwardDiff -dZygote = Zygote.gradient(input -> discrete_cost_forward(input, InterpolatingAdjoint()), input)[1] +dZygote = Zygote.gradient(input -> discrete_cost_forward(input, InterpolatingAdjoint()), + input)[1] @test dZygote ≈ dForwardDiff -dZygote = Zygote.gradient(input -> discrete_cost_forward(input, QuadratureAdjoint()), input)[1] +dZygote = Zygote.gradient(input -> discrete_cost_forward(input, QuadratureAdjoint()), + input)[1] @test dZygote ≈ dForwardDiff ## - ## Mixed costs function mixed_cost_forward(input) - u0 = input[1:2] - p = input[3:end] - - prob = ODEProblem(fiip, u0, (0.0, 10.0), p) - sol = solve(prob, Tsit5(), abstol=abstol, reltol=reltol, save_start=false, save_end=false) - cost, err = quadgk((t) -> sol(t)[1]^2 + p[1], prob.tspan..., atol=abstol, rtol=reltol) - for t in savingtimes - cost += (sol(t)[1]^2) - end - return cost + u0 = input[1:2] + p = input[3:end] + + prob = ODEProblem(fiip, u0, (0.0, 10.0), p) + sol = solve(prob, Tsit5(), abstol = abstol, reltol = reltol, save_start = false, + save_end = false) + cost, err = quadgk((t) -> sol(t)[1]^2 + p[1], prob.tspan..., atol = abstol, + rtol = reltol) + for t in savingtimes + cost += (sol(t)[1]^2) + end + return cost end dForwardDiff = ForwardDiff.gradient(mixed_cost_forward, input) dFiniteDiff = FiniteDiff.finite_difference_gradient(mixed_cost_forward, input) @test dForwardDiff ≈ dFiniteDiff function dgdu_discrete(out, u, p, t, i) - out[1] = 2u[1] - out[2] = 0.0 + out[1] = 2u[1] + out[2] = 0.0 end function dgdu_continuous(out, u, p, t) - out[1] = 2u[1] - out[2] = 0.0 + out[1] = 2u[1] + out[2] = 0.0 end function dgdp_continuous(out, u, p, t) - out[1] = 1.0 - out[2] = 0.0 - out[3] = 0.0 - out[4] = 0.0 + out[1] = 1.0 + out[2] = 0.0 + out[3] = 0.0 + out[4] = 0.0 end # BacksolveAdjoint, all vjps -du0, dp = adjoint_sensitivities(sol, Tsit5(), t=savingtimes, dg_discrete=dgdu_discrete, dg_continuous=(dgdu_continuous, dgdp_continuous), sensealg=BacksolveAdjoint(autojacvec=EnzymeVJP()), abstol=abstol, reltol=reltol) +du0, dp = adjoint_sensitivities(sol, Tsit5(), t = savingtimes, dg_discrete = dgdu_discrete, + dg_continuous = (dgdu_continuous, dgdp_continuous), + sensealg = BacksolveAdjoint(autojacvec = EnzymeVJP()), + abstol = abstol, reltol = reltol) @test du0 ≈ dForwardDiff[1:2] @test dp' ≈ dForwardDiff[3:6] -du0, dp = adjoint_sensitivities(sol, Tsit5(), t=savingtimes, dg_discrete=dgdu_discrete, dg_continuous=(dgdu_continuous, dgdp_continuous), sensealg=BacksolveAdjoint(autojacvec=ReverseDiffVJP()), abstol=abstol, reltol=reltol) +du0, dp = adjoint_sensitivities(sol, Tsit5(), t = savingtimes, dg_discrete = dgdu_discrete, + dg_continuous = (dgdu_continuous, dgdp_continuous), + sensealg = BacksolveAdjoint(autojacvec = ReverseDiffVJP()), + abstol = abstol, reltol = reltol) @test du0 ≈ dForwardDiff[1:2] @test dp' ≈ dForwardDiff[3:6] -du0, dp = adjoint_sensitivities(sol, Tsit5(), t=savingtimes, dg_discrete=dgdu_discrete, dg_continuous=(dgdu_continuous, dgdp_continuous), sensealg=BacksolveAdjoint(autojacvec=TrackerVJP()), abstol=abstol, reltol=reltol) +du0, dp = adjoint_sensitivities(sol, Tsit5(), t = savingtimes, dg_discrete = dgdu_discrete, + dg_continuous = (dgdu_continuous, dgdp_continuous), + sensealg = BacksolveAdjoint(autojacvec = TrackerVJP()), + abstol = abstol, reltol = reltol) @test du0 ≈ dForwardDiff[1:2] @test dp' ≈ dForwardDiff[3:6] -du0, dp = adjoint_sensitivities(sol, Tsit5(), t=savingtimes, dg_discrete=dgdu_discrete, dg_continuous=(dgdu_continuous, dgdp_continuous), sensealg=BacksolveAdjoint(autojacvec=false), abstol=abstol, reltol=reltol) +du0, dp = adjoint_sensitivities(sol, Tsit5(), t = savingtimes, dg_discrete = dgdu_discrete, + dg_continuous = (dgdu_continuous, dgdp_continuous), + sensealg = BacksolveAdjoint(autojacvec = false), + abstol = abstol, reltol = reltol) @test du0 ≈ dForwardDiff[1:2] @test dp' ≈ dForwardDiff[3:6] -du0, dp = adjoint_sensitivities(sol, Tsit5(), t=savingtimes, dg_discrete=dgdu_discrete, dg_continuous=(dgdu_continuous, dgdp_continuous), sensealg=BacksolveAdjoint(autojacvec=ZygoteVJP()), abstol=abstol, reltol=reltol) +du0, dp = adjoint_sensitivities(sol, Tsit5(), t = savingtimes, dg_discrete = dgdu_discrete, + dg_continuous = (dgdu_continuous, dgdp_continuous), + sensealg = BacksolveAdjoint(autojacvec = ZygoteVJP()), + abstol = abstol, reltol = reltol) @test du0 ≈ dForwardDiff[1:2] @test dp' ≈ dForwardDiff[3:6] # InterpolatingAdjoint, all vjps -du0, dp = adjoint_sensitivities(sol, Tsit5(), t=savingtimes, dg_discrete=dgdu_discrete, dg_continuous=(dgdu_continuous, dgdp_continuous), sensealg=InterpolatingAdjoint(autojacvec=EnzymeVJP()), abstol=abstol, reltol=reltol) +du0, dp = adjoint_sensitivities(sol, Tsit5(), t = savingtimes, dg_discrete = dgdu_discrete, + dg_continuous = (dgdu_continuous, dgdp_continuous), + sensealg = InterpolatingAdjoint(autojacvec = EnzymeVJP()), + abstol = abstol, reltol = reltol) @test du0 ≈ dForwardDiff[1:2] @test dp' ≈ dForwardDiff[3:6] -du0, dp = adjoint_sensitivities(sol, Tsit5(), t=savingtimes, dg_discrete=dgdu_discrete, dg_continuous=(dgdu_continuous, dgdp_continuous), sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP()), abstol=abstol, reltol=reltol) +du0, dp = adjoint_sensitivities(sol, Tsit5(), t = savingtimes, dg_discrete = dgdu_discrete, + dg_continuous = (dgdu_continuous, dgdp_continuous), + sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP()), + abstol = abstol, reltol = reltol) @test du0 ≈ dForwardDiff[1:2] @test dp' ≈ dForwardDiff[3:6] -du0, dp = adjoint_sensitivities(sol, Tsit5(), t=savingtimes, dg_discrete=dgdu_discrete, dg_continuous=(dgdu_continuous, dgdp_continuous), sensealg=InterpolatingAdjoint(autojacvec=TrackerVJP()), abstol=abstol, reltol=reltol) +du0, dp = adjoint_sensitivities(sol, Tsit5(), t = savingtimes, dg_discrete = dgdu_discrete, + dg_continuous = (dgdu_continuous, dgdp_continuous), + sensealg = InterpolatingAdjoint(autojacvec = TrackerVJP()), + abstol = abstol, reltol = reltol) @test du0 ≈ dForwardDiff[1:2] @test dp' ≈ dForwardDiff[3:6] -du0, dp = adjoint_sensitivities(sol, Tsit5(), t=savingtimes, dg_discrete=dgdu_discrete, dg_continuous=(dgdu_continuous, dgdp_continuous), sensealg=InterpolatingAdjoint(autojacvec=false), abstol=abstol, reltol=reltol) +du0, dp = adjoint_sensitivities(sol, Tsit5(), t = savingtimes, dg_discrete = dgdu_discrete, + dg_continuous = (dgdu_continuous, dgdp_continuous), + sensealg = InterpolatingAdjoint(autojacvec = false), + abstol = abstol, reltol = reltol) @test du0 ≈ dForwardDiff[1:2] @test dp' ≈ dForwardDiff[3:6] -du0, dp = adjoint_sensitivities(sol, Tsit5(), t=savingtimes, dg_discrete=dgdu_discrete, dg_continuous=(dgdu_continuous, dgdp_continuous), sensealg=InterpolatingAdjoint(autojacvec=ZygoteVJP()), abstol=abstol, reltol=reltol) +du0, dp = adjoint_sensitivities(sol, Tsit5(), t = savingtimes, dg_discrete = dgdu_discrete, + dg_continuous = (dgdu_continuous, dgdp_continuous), + sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP()), + abstol = abstol, reltol = reltol) @test du0 ≈ dForwardDiff[1:2] @test dp' ≈ dForwardDiff[3:6] # QuadratureAdjoint, all vjps -du0, dp = adjoint_sensitivities(sol, Tsit5(), t=savingtimes, dg_discrete=dgdu_discrete, dg_continuous=(dgdu_continuous, dgdp_continuous), sensealg=QuadratureAdjoint(autojacvec=EnzymeVJP(), abstol=abstol, reltol=reltol), abstol=abstol, reltol=reltol) -@test du0 ≈ dForwardDiff[1:2] -@test dp' ≈ dForwardDiff[3:6] -du0, dp = adjoint_sensitivities(sol, Tsit5(), t=savingtimes, dg_discrete=dgdu_discrete, dg_continuous=(dgdu_continuous, dgdp_continuous), sensealg=QuadratureAdjoint(autojacvec=ReverseDiffVJP(), abstol=abstol, reltol=reltol), abstol=abstol, reltol=reltol) -@test du0 ≈ dForwardDiff[1:2] -@test dp' ≈ dForwardDiff[3:6] -du0, dp = adjoint_sensitivities(sol, Tsit5(), t=savingtimes, dg_discrete=dgdu_discrete, dg_continuous=(dgdu_continuous, dgdp_continuous), sensealg=QuadratureAdjoint(autojacvec=false, abstol=abstol, reltol=reltol), abstol=abstol, reltol=reltol) +du0, dp = adjoint_sensitivities(sol, Tsit5(), t = savingtimes, dg_discrete = dgdu_discrete, + dg_continuous = (dgdu_continuous, dgdp_continuous), + sensealg = QuadratureAdjoint(autojacvec = EnzymeVJP(), + abstol = abstol, + reltol = reltol), + abstol = abstol, reltol = reltol) +@test du0 ≈ dForwardDiff[1:2] +@test dp' ≈ dForwardDiff[3:6] +du0, dp = adjoint_sensitivities(sol, Tsit5(), t = savingtimes, dg_discrete = dgdu_discrete, + dg_continuous = (dgdu_continuous, dgdp_continuous), + sensealg = QuadratureAdjoint(autojacvec = ReverseDiffVJP(), + abstol = abstol, + reltol = reltol), + abstol = abstol, reltol = reltol) +@test du0 ≈ dForwardDiff[1:2] +@test dp' ≈ dForwardDiff[3:6] +du0, dp = adjoint_sensitivities(sol, Tsit5(), t = savingtimes, dg_discrete = dgdu_discrete, + dg_continuous = (dgdu_continuous, dgdp_continuous), + sensealg = QuadratureAdjoint(autojacvec = false, + abstol = abstol, + reltol = reltol), + abstol = abstol, reltol = reltol) @test du0 ≈ dForwardDiff[1:2] @test dp' ≈ dForwardDiff[3:6] diff --git a/test/null_parameters.jl b/test/null_parameters.jl index 0ac54e8b7..ca03b33b5 100644 --- a/test/null_parameters.jl +++ b/test/null_parameters.jl @@ -7,63 +7,72 @@ dynamics = (x, _p, _t) -> x function loss(params) u0 = zeros(2) problem = ODEProblem(dynamics, u0, (0.0, 1.0), params) - rollout = solve(problem, Tsit5(), u0 = u0, p = params, sensealg = InterpolatingAdjoint(autojacvec=ZygoteVJP(allow_nothing=true))) + rollout = solve(problem, Tsit5(), u0 = u0, p = params, + sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP(allow_nothing = true))) sum(Array(rollout)[:, end]) end function loss2(params) u0 = zeros(2) problem = ODEProblem(dynamics, u0, (0.0, 1.0), params) - rollout = solve(problem, Tsit5(), u0 = u0, p = params, sensealg = InterpolatingAdjoint(autojacvec=ReverseDiffVJP())) + rollout = solve(problem, Tsit5(), u0 = u0, p = params, + sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP())) sum(Array(rollout)[:, end]) end function loss3(params) u0 = zeros(2) problem = ODEProblem(dynamics, u0, (0.0, 1.0), params) - rollout = solve(problem, Tsit5(), u0 = u0, p = params, sensealg = InterpolatingAdjoint(autojacvec=TrackerVJP(allow_nothing=true))) + rollout = solve(problem, Tsit5(), u0 = u0, p = params, + sensealg = InterpolatingAdjoint(autojacvec = TrackerVJP(allow_nothing = true))) sum(Array(rollout)[:, end]) end function loss4(params) u0 = zeros(2) problem = ODEProblem(dynamics, u0, (0.0, 1.0)) - rollout = solve(problem, Tsit5(), u0 = u0, p = params, sensealg = InterpolatingAdjoint(autojacvec=ZygoteVJP(allow_nothing=true))) + rollout = solve(problem, Tsit5(), u0 = u0, p = params, + sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP(allow_nothing = true))) sum(Array(rollout)[:, end]) end function loss5(params) u0 = zeros(2) problem = ODEProblem(dynamics, u0, (0.0, 1.0)) - rollout = solve(problem, Tsit5(), u0 = u0, p = params, sensealg = InterpolatingAdjoint(autojacvec=EnzymeVJP())) + rollout = solve(problem, Tsit5(), u0 = u0, p = params, + sensealg = InterpolatingAdjoint(autojacvec = EnzymeVJP())) sum(Array(rollout)[:, end]) end function loss6(params) u0 = zeros(2) problem = ODEProblem(dynamics, u0, (0.0, 1.0)) - rollout = solve(problem, Tsit5(), u0 = u0, p = params, sensealg = BacksolveAdjoint(autojacvec=ZygoteVJP(allow_nothing=true))) + rollout = solve(problem, Tsit5(), u0 = u0, p = params, + sensealg = BacksolveAdjoint(autojacvec = ZygoteVJP(allow_nothing = true))) sum(Array(rollout)[:, end]) end function loss7(params) u0 = zeros(2) problem = ODEProblem(dynamics, u0, (0.0, 1.0)) - rollout = solve(problem, Tsit5(), u0 = u0, p = params, sensealg = QuadratureAdjoint(autojacvec=ZygoteVJP(allow_nothing=true))) + rollout = solve(problem, Tsit5(), u0 = u0, p = params, + sensealg = QuadratureAdjoint(autojacvec = ZygoteVJP(allow_nothing = true))) sum(Array(rollout)[:, end]) end function loss8(params) u0 = zeros(2) problem = ODEProblem(dynamics, u0, (0.0, 1.0)) - rollout = solve(problem, Tsit5(), u0 = u0, p = params, sensealg = QuadratureAdjoint(autojacvec=ReverseDiffVJP())) + rollout = solve(problem, Tsit5(), u0 = u0, p = params, + sensealg = QuadratureAdjoint(autojacvec = ReverseDiffVJP())) sum(Array(rollout)[:, end]) end function loss9(params) u0 = zeros(2) problem = ODEProblem(dynamics, u0, (0.0, 1.0)) - rollout = solve(problem, Tsit5(), u0 = u0, p = params, sensealg = QuadratureAdjoint(autojacvec=EnzymeVJP())) + rollout = solve(problem, Tsit5(), u0 = u0, p = params, + sensealg = QuadratureAdjoint(autojacvec = EnzymeVJP())) sum(Array(rollout)[:, end]) end @@ -74,7 +83,7 @@ function loss10(params) sum(Array(rollout)[:, end]) end -@test Zygote.gradient(dynamics, 0.0, nothing, nothing) == (1.0,nothing,nothing) +@test Zygote.gradient(dynamics, 0.0, nothing, nothing) == (1.0, nothing, nothing) @test Zygote.gradient(loss, nothing)[1] === nothing @test_broken Zygote.gradient(loss2, nothing) @@ -96,4 +105,4 @@ end @test_broken Zygote.gradient(loss7, zeros(123))[1] == zeros(123) @test Zygote.gradient(loss8, zeros(123))[1] == zeros(123) @test Zygote.gradient(loss9, zeros(123))[1] == zeros(123) -@test_throws SciMLSensitivity.ZygoteVJPNothingError Zygote.gradient(loss10, zeros(123))[1] == zeros(123) +@test_throws SciMLSensitivity.ZygoteVJPNothingError Zygote.gradient(loss10, zeros(123))[1]==zeros(123) diff --git a/test/parameter_compatibility_errors.jl b/test/parameter_compatibility_errors.jl index e8ac92cf0..7b47f03e3 100644 --- a/test/parameter_compatibility_errors.jl +++ b/test/parameter_compatibility_errors.jl @@ -1,50 +1,55 @@ using OrdinaryDiffEq, SciMLSensitivity, Zygote, Test -function f!(du,u,p,t) - du[1] = -p[1]'*u +function f!(du, u, p, t) + du[1] = -p[1]' * u du[2] = (p[2].a + p[2].b)u[2] - du[3] = p[3](u,t) + du[3] = p[3](u, t) return nothing end struct mystruct - a - b + a::Any + b::Any end -function control(u,t) - return -exp(-t)*u[3] +function control(u, t) + return -exp(-t) * u[3] end -u0 = [10,15,20] -p = [[1;2;3], mystruct(-1,-2), control] -tspan = (0.0,10.0) +u0 = [10, 15, 20] +p = [[1; 2; 3], mystruct(-1, -2), control] +tspan = (0.0, 10.0) -prob = ODEProblem(f!,u0, tspan, p) +prob = ODEProblem(f!, u0, tspan, p) sol = solve(prob, Tsit5()) # Solves without errors function loss(p1) - sol = solve(prob, Tsit5(), p=[p1, mystruct(-1,-2), control]) + sol = solve(prob, Tsit5(), p = [p1, mystruct(-1, -2), control]) return sum(abs2, sol) end grad(p) = Zygote.gradient(loss, p) -p2 = [4;5;6] +p2 = [4; 5; 6] @test_throws SciMLSensitivity.ForwardDiffSensitivityParameterCompatibilityError grad(p2) function loss(p1) - sol = solve(prob, Tsit5(), p=[p1, mystruct(-1,-2), control], sensealg = InterpolatingAdjoint()) + sol = solve(prob, Tsit5(), p = [p1, mystruct(-1, -2), control], + sensealg = InterpolatingAdjoint()) return sum(abs2, sol) end @test_throws SciMLSensitivity.AdjointSensitivityParameterCompatibilityError grad(p2) function loss(p1) - sol = solve(prob, Tsit5(), p=[p1, mystruct(-1,-2), control], sensealg = ForwardSensitivity()) + sol = solve(prob, Tsit5(), p = [p1, mystruct(-1, -2), control], + sensealg = ForwardSensitivity()) return sum(abs2, sol) end @test_throws SciMLSensitivity.ForwardSensitivityParameterCompatibilityError grad(p2) -@test_throws SciMLSensitivity.ForwardSensitivityParameterCompatibilityError ODEForwardSensitivityProblem(f!,u0, tspan, p) \ No newline at end of file +@test_throws SciMLSensitivity.ForwardSensitivityParameterCompatibilityError ODEForwardSensitivityProblem(f!, + u0, + tspan, + p) diff --git a/test/partial_neural.jl b/test/partial_neural.jl index 0f44f24ef..9de382bc9 100644 --- a/test/partial_neural.jl +++ b/test/partial_neural.jl @@ -1,37 +1,38 @@ -using SciMLSensitivity, Flux, Optimization, OptimizationFlux, OptimizationOptimJL, OrdinaryDiffEq, Test - +using SciMLSensitivity, Flux, Optimization, OptimizationFlux, OptimizationOptimJL, + OrdinaryDiffEq, Test x = Float32[0.8; 0.8] -tspan = (0.0f0,10.0f0) +tspan = (0.0f0, 10.0f0) -ann = Chain(Dense(2,10,tanh), Dense(10,1)) -p = Float32[-2.0,1.1] -p2,re = Flux.destructure(ann) -_p = [p;p2] -θ = [x;_p] +ann = Chain(Dense(2, 10, tanh), Dense(10, 1)) +p = Float32[-2.0, 1.1] +p2, re = Flux.destructure(ann) +_p = [p; p2] +θ = [x; _p] -function dudt2_(u,p,t) +function dudt2_(u, p, t) x, y = u - [(re(p[3:end])(u)[1]),p[1]*y + p[2]*x] + [(re(p[3:end])(u)[1]), p[1] * y + p[2] * x] end -prob = ODEProblem(dudt2_,x,tspan,_p) -solve(prob,Tsit5()) +prob = ODEProblem(dudt2_, x, tspan, _p) +solve(prob, Tsit5()) function predict_rd(θ) - Array(solve(prob,Tsit5(),u0=θ[1:2],p=θ[3:end],abstol=1e-7,reltol=1e-5,sensealg=TrackerAdjoint())) + Array(solve(prob, Tsit5(), u0 = θ[1:2], p = θ[3:end], abstol = 1e-7, reltol = 1e-5, + sensealg = TrackerAdjoint())) end -loss_rd(p) = sum(abs2,x-1 for x in predict_rd(p)) +loss_rd(p) = sum(abs2, x - 1 for x in predict_rd(p)) l = loss_rd(θ) -cb = function (θ,l) - @show l - # display(plot(solve(remake(prob,u0=Flux.data(_x),p=Flux.data(p)),Tsit5(),saveat=0.1),ylim=(0,6))) - false +cb = function (θ, l) + @show l + # display(plot(solve(remake(prob,u0=Flux.data(_x),p=Flux.data(p)),Tsit5(),saveat=0.1),ylim=(0,6))) + false end # Display the ODE with the current parameter values. -cb(θ,l) +cb(θ, l) loss1 = loss_rd(θ) optfunc = Optimization.OptimizationFunction((x, p) -> loss_rd(x), Optimization.AutoZygote()) @@ -43,40 +44,41 @@ loss2 = res.minimum ## Partial Neural Adjoint u0 = Float32[0.8; 0.8] -tspan = (0.0f0,25.0f0) +tspan = (0.0f0, 25.0f0) -ann = Chain(Dense(2,10,tanh), Dense(10,1)) +ann = Chain(Dense(2, 10, tanh), Dense(10, 1)) -p1,re = Flux.destructure(ann) -p2 = Float32[-2.0,1.1] -p3 = [p1;p2] -θ = [u0;p3] +p1, re = Flux.destructure(ann) +p2 = Float32[-2.0, 1.1] +p3 = [p1; p2] +θ = [u0; p3] -function dudt_(du,u,p,t) +function dudt_(du, u, p, t) x, y = u du[1] = re(p[1:41])(u)[1] - du[2] = p[end-1]*y + p[end]*x + du[2] = p[end - 1] * y + p[end] * x end -prob = ODEProblem(dudt_,u0,tspan,p3) -solve(prob,Tsit5(),abstol=1e-8,reltol=1e-6) +prob = ODEProblem(dudt_, u0, tspan, p3) +solve(prob, Tsit5(), abstol = 1e-8, reltol = 1e-6) function predict_adjoint(θ) - Array(solve(prob,Tsit5(),u0=θ[1:2],p=θ[3:end],saveat=0.0:1:25.0)) + Array(solve(prob, Tsit5(), u0 = θ[1:2], p = θ[3:end], saveat = 0.0:1:25.0)) end -loss_adjoint(θ) = sum(abs2,x-1 for x in predict_adjoint(θ)) +loss_adjoint(θ) = sum(abs2, x - 1 for x in predict_adjoint(θ)) l = loss_adjoint(θ) -cb = function (θ,l) - @show l - # display(plot(solve(remake(prob,p=Flux.data(p3),u0=Flux.data(u0)),Tsit5(),saveat=0.1),ylim=(0,6))) - false +cb = function (θ, l) + @show l + # display(plot(solve(remake(prob,p=Flux.data(p3),u0=Flux.data(u0)),Tsit5(),saveat=0.1),ylim=(0,6))) + false end # Display the ODE with the current parameter values. -cb(θ,l) +cb(θ, l) loss1 = loss_adjoint(θ) -optfunc = Optimization.OptimizationFunction((x, p) -> loss_adjoint(x), Optimization.AutoZygote()) +optfunc = Optimization.OptimizationFunction((x, p) -> loss_adjoint(x), + Optimization.AutoZygote()) optprob = Optimization.OptimizationProblem(optfunc, θ) res1 = Optimization.solve(optprob, ADAM(0.01), callback = cb, maxiters = 100) diff --git a/test/prob_kwargs.jl b/test/prob_kwargs.jl index ec39de417..a430a9dbc 100644 --- a/test/prob_kwargs.jl +++ b/test/prob_kwargs.jl @@ -1,29 +1,31 @@ -using OrdinaryDiffEq, SciMLSensitivity - -function growth(du, u, p, t) - @. du = p * u * (1 - u) -end -u0 = [0.1] -tspan = (0.0, 2.0) -prob = ODEProblem(growth, u0, tspan, [1.0]) -sol = solve(prob, Tsit5(), reltol = 1e-8, abstol = 1e-8) - -savetimes = [0.0, 1.0, 1.9] - -function f(a) - _prob = remake(prob,p=[a[1]],saveat=savetimes) - predicted = solve(_prob, Tsit5(), sensealg=InterpolatingAdjoint(), abstol=1e-12, reltol=1e-12) - sum(predicted[end]) -end - -function f2(a) - _prob = remake(prob,p=[a[1]],saveat=savetimes) - predicted = solve(_prob, Tsit5(), sensealg=ForwardDiffSensitivity(), abstol=1e-12, reltol=1e-12) - sum(predicted[end]) -end - -using Zygote -a = ones(3) -@test Zygote.gradient(f,a)[1][1] ≈ Zygote.gradient(f2,a)[1][1] -@test Zygote.gradient(f,a)[1][2] == Zygote.gradient(f2,a)[1][2] == 0 -@test Zygote.gradient(f,a)[1][3] == Zygote.gradient(f2,a)[1][3] == 0 +using OrdinaryDiffEq, SciMLSensitivity + +function growth(du, u, p, t) + @. du = p * u * (1 - u) +end +u0 = [0.1] +tspan = (0.0, 2.0) +prob = ODEProblem(growth, u0, tspan, [1.0]) +sol = solve(prob, Tsit5(), reltol = 1e-8, abstol = 1e-8) + +savetimes = [0.0, 1.0, 1.9] + +function f(a) + _prob = remake(prob, p = [a[1]], saveat = savetimes) + predicted = solve(_prob, Tsit5(), sensealg = InterpolatingAdjoint(), abstol = 1e-12, + reltol = 1e-12) + sum(predicted[end]) +end + +function f2(a) + _prob = remake(prob, p = [a[1]], saveat = savetimes) + predicted = solve(_prob, Tsit5(), sensealg = ForwardDiffSensitivity(), abstol = 1e-12, + reltol = 1e-12) + sum(predicted[end]) +end + +using Zygote +a = ones(3) +@test Zygote.gradient(f, a)[1][1] ≈ Zygote.gradient(f2, a)[1][1] +@test Zygote.gradient(f, a)[1][2] == Zygote.gradient(f2, a)[1][2] == 0 +@test Zygote.gradient(f, a)[1][3] == Zygote.gradient(f2, a)[1][3] == 0 diff --git a/test/rode.jl b/test/rode.jl index fd146d601..4e5a517bc 100644 --- a/test/rode.jl +++ b/test/rode.jl @@ -8,460 +8,521 @@ using Test seed = 12345 Random.seed!(seed) -function g(u,p,t) - sum(u.^2.0/2.0) +function g(u, p, t) + sum(u .^ 2.0 / 2.0) end -function dg!(out,u,p,t,i) - (out.=u) +function dg!(out, u, p, t, i) + (out .= u) end @testset "noise iip tests" begin - function f(du,u,p,t,W) - du[1] = p[1]*u[1]*sin(W[1] - W[2]) - du[2] = p[2]*u[2]*cos(W[1] + W[2]) - return nothing - end - dt = 1e-4 - u0 = [1.00;1.00] - tspan = (0.0,5.0) - t = tspan[1]:0.1:tspan[2] - p = [2.0,-2.0] - prob = RODEProblem(f,u0,tspan,p) - - sol = solve(prob,RandomEM(),dt=dt, save_noise=true) - # check reversion with usage of Noise Grid - _sol = deepcopy(sol) - noise_reverse = NoiseGrid(reverse(_sol.t),reverse(_sol.W.W)) - prob_reverse = RODEProblem(f,_sol[end],reverse(tspan),p,noise=noise_reverse) - sol_reverse = solve(prob_reverse,RandomEM(),dt=dt) - @test sol.u ≈ reverse(sol_reverse.u) rtol=1e-3 - @show minimum(sol.u) - - # Test if Forward and ReverseMode AD agree. - Random.seed!(seed) - du0ReverseDiff,dpReverseDiff = Zygote.gradient((u0,p)->sum( - Array(solve(prob,RandomEM(),dt=dt,u0=u0,p=p,saveat=t,sensealg=ReverseDiffAdjoint())).^2/2) - ,u0,p) - Random.seed!(seed) - dForward = ForwardDiff.gradient((θ)->sum( - Array(solve(prob,RandomEM(),dt=dt,u0=θ[1:2],p=θ[3:4],saveat=t)).^2/2) - ,[u0;p]) - - @info dForward - - @test du0ReverseDiff ≈ dForward[1:2] - @test dpReverseDiff ≈ dForward[3:4] - - # test gradients - Random.seed!(seed) - sol = solve(prob,RandomEM(),dt=dt, save_noise=true, saveat=t) - - - ### - ## BacksolveAdjoint - ### - - # ReverseDiff - du0, dp = adjoint_sensitivities(sol,RandomEM(),t=Array(t),dg_discrete=dg!, - dt=dt,adaptive=false,sensealg=BacksolveAdjoint()) - - @test du0ReverseDiff ≈ du0 rtol=1e-2 - @test dpReverseDiff ≈ dp' rtol=1e-2 - - @info du0, dp' - - # ReverseDiff with compiled tape - du0, dp = adjoint_sensitivities(sol,RandomEM(),t=Array(t),dg_discrete=dg!, - dt=dt,adaptive=false,sensealg=BacksolveAdjoint(autojacvec=ReverseDiffVJP(true))) - - @test du0ReverseDiff ≈ du0 rtol=1e-2 - @test dpReverseDiff ≈ dp' rtol=1e-2 - - - du0, dp = adjoint_sensitivities(sol,RandomEM(),t=Array(t),dg_discrete=dg!, - dt=dt,adaptive=false,sensealg=BacksolveAdjoint(autojacvec=ZygoteVJP())) - - @test du0ReverseDiff ≈ du0 rtol=1e-2 - @test dpReverseDiff ≈ dp' rtol=1e-2 - - # Tracker - du0, dp = adjoint_sensitivities(sol,RandomEM(),t=Array(t),dg_discrete=dg!, - dt=dt,adaptive=false,sensealg=BacksolveAdjoint(autojacvec=TrackerVJP())) - - @test du0ReverseDiff ≈ du0 rtol=1e-2 - @test dpReverseDiff ≈ dp' rtol=1e-2 - - # isautojacvec = false - du0, dp = adjoint_sensitivities(sol,RandomEM(),t=Array(t),dg_discrete=dg!, - dt=dt,adaptive=false,sensealg=BacksolveAdjoint(autojacvec=false)) - - @test du0ReverseDiff ≈ du0 rtol=1e-2 - @test dpReverseDiff ≈ dp' rtol=1e-2 - - # isautojacvec = false and with jac and paramjac - function jac(J,u,p,t,W) - J[1,1] = p[1]*sin(W[1] - W[2]) - J[2,1] = zero(u[1]) - J[1,2] = zero(u[1]) - J[2,2] = p[2]*cos(W[1] + W[2]) - end - - function paramjac(J,u,p,t,W) - J[1,1] = u[1]*sin(W[1] - W[2]) - J[2,1] = zero(u[1]) - J[1,2] = zero(u[1]) - J[2,2] = u[2]*cos(W[1] + W[2]) - end - Random.seed!(seed) - faug = RODEFunction(f,jac=jac,paramjac=paramjac) - prob_aug = RODEProblem{true}(faug,u0,tspan,p) - sol = solve(prob_aug,RandomEM(),dt=dt, save_noise=true, saveat=t) - du0, dp = adjoint_sensitivities(sol,RandomEM(),t=Array(t),dg_discrete=dg!, - dt=dt,adaptive=false,sensealg=BacksolveAdjoint(autojacvec=false)) - - @test du0ReverseDiff ≈ du0 rtol=1e-2 - @test dpReverseDiff ≈ dp' rtol=1e-2 - - ### - ## InterpolatingAdjoint - ### - - # test gradients with dense solution and no checkpointing - Random.seed!(seed) - sol = solve(prob,RandomEM(),dt=dt, save_noise=true, dense=true) - - # ReverseDiff - du0, dp = adjoint_sensitivities(sol,RandomEM(),t=Array(t),dg_discrete=dg!, - dt=dt,adaptive=false,sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP())) - - @test du0ReverseDiff ≈ du0 rtol=1e-2 - @test dpReverseDiff ≈ dp' rtol=1e-2 - - # ReverseDiff with compiled tape - du0, dp = adjoint_sensitivities(sol,RandomEM(),t=Array(t),dg_discrete=dg!, - dt=dt,adaptive=false,sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true))) - - @test du0ReverseDiff ≈ du0 rtol=1e-2 - @test dpReverseDiff ≈ dp' rtol=1e-2 - - # Zygote - du0, dp = adjoint_sensitivities(sol,RandomEM(),t=Array(t),dg_discrete=dg!, - dt=dt,adaptive=false,sensealg=InterpolatingAdjoint(autojacvec=ZygoteVJP())) - - @test du0ReverseDiff ≈ du0 rtol=1e-2 - @test dpReverseDiff ≈ dp' rtol=1e-2 - - # Tracker - du0, dp = adjoint_sensitivities(sol,RandomEM(),t=Array(t),dg_discrete=dg!, - dt=dt,adaptive=false,sensealg=InterpolatingAdjoint(autojacvec=TrackerVJP())) - - @test du0ReverseDiff ≈ du0 rtol=1e-2 - @test dpReverseDiff ≈ dp' rtol=1e-2 - - # isautojacvec = false - du0, dp = adjoint_sensitivities(sol,RandomEM(),t=Array(t),dg_discrete=dg!, - dt=dt,adaptive=false,sensealg=InterpolatingAdjoint(autojacvec=false)) - - @test du0ReverseDiff ≈ du0 rtol=1e-2 - @test dpReverseDiff ≈ dp' rtol=1e-2 - - # isautojacvec = false and with jac and paramjac - Random.seed!(seed) - faug = RODEFunction(f,jac=jac,paramjac=paramjac) - prob_aug = RODEProblem{true}(faug,u0,tspan,p) - sol = solve(prob_aug,RandomEM(),dt=dt, save_noise=true, dense=true) - du0, dp = adjoint_sensitivities(sol,RandomEM(),t=Array(t),dg_discrete=dg!, - dt=dt,adaptive=false,sensealg=InterpolatingAdjoint(autojacvec=false)) - - @test du0ReverseDiff ≈ du0 rtol=1e-2 - @test dpReverseDiff ≈ dp' rtol=1e-2 - - # test gradients with saveat solution and checkpointing - # need to simulate for dt beyond last tspan to avoid errors in NoiseGrid - Random.seed!(seed) - sol = solve(prob,RandomEM(),dt=dt, save_noise=true, dense=true) - Random.seed!(seed) - sol_long = solve(remake(prob, tspan=(tspan[1],tspan[2]+10dt)),RandomEM(),dt=dt, save_noise=true, dense=true) - - @test sol_long(t) ≈ sol(t) rtol=1e-12 - @test sol_long.W.W[1:end-10] ≈ sol.W.W[1:end] rtol=1e-12 - - # test gradients with saveat solution and checkpointing - noise = NoiseGrid(sol_long.W.t,sol_long.W.W) - sol2 = solve(remake(prob,noise=noise,tspan=(tspan[1],tspan[2])),RandomEM(),dt=dt, saveat=t) - - @test sol_long(t) ≈ sol2(t) rtol=1e-12 - @test sol_long.W.W ≈ sol2.W.W rtol=1e-12 - - # ReverseDiff - du0, dp = adjoint_sensitivities(sol2,RandomEM(),t=Array(t),dg_discrete=dg!, - dt=dt,adaptive=false,sensealg=InterpolatingAdjoint(checkpointing=true,autojacvec=ReverseDiffVJP())) - - @test du0ReverseDiff ≈ du0 rtol=1e-2 - @test dpReverseDiff ≈ dp' rtol=1e-2 - - # ReverseDiff with compiled tape - du0, dp = adjoint_sensitivities(sol2,RandomEM(),t=Array(t),dg_discrete=dg!, - dt=dt,adaptive=false,sensealg=InterpolatingAdjoint(checkpointing=true,autojacvec=ReverseDiffVJP(true))) - - @test du0ReverseDiff ≈ du0 rtol=1e-2 - @test dpReverseDiff ≈ dp' rtol=1e-2 - - # Zygote - du0, dp = adjoint_sensitivities(sol2,RandomEM(),t=Array(t),dg_discrete=dg!, - dt=dt,adaptive=false,sensealg=InterpolatingAdjoint(checkpointing=true,autojacvec=ZygoteVJP())) - - @test du0ReverseDiff ≈ du0 rtol=1e-2 - @test dpReverseDiff ≈ dp' rtol=1e-2 - - # Tracker - du0, dp = adjoint_sensitivities(sol2,RandomEM(),t=Array(t),dg_discrete=dg!, - dt=dt,adaptive=false,sensealg=InterpolatingAdjoint(checkpointing=true,autojacvec=TrackerVJP())) - - @test du0ReverseDiff ≈ du0 rtol=1e-2 - @test dpReverseDiff ≈ dp' rtol=1e-2 - - # isautojacvec = false - du0, dp = adjoint_sensitivities(sol2,RandomEM(),t=Array(t),dg_discrete=dg!, - dt=dt,adaptive=false,sensealg=InterpolatingAdjoint(checkpointing=true,autojacvec=false)) - - @test du0ReverseDiff ≈ du0 rtol=1e-2 - @test dpReverseDiff ≈ dp' rtol=1e-2 - - # isautojacvec = false and with jac and paramjac - Random.seed!(seed) - faug = RODEFunction(f,jac=jac,paramjac=paramjac) - prob_aug = RODEProblem{true}(faug,u0,tspan,p, noise=noise) - sol = solve(prob_aug,RandomEM(),dt=dt, save_noise=false, dense=false, saveat=t) - du0, dp = adjoint_sensitivities(sol,RandomEM(),t=Array(t),dg_discrete=dg!, - dt=dt,adaptive=false,sensealg=InterpolatingAdjoint(checkpointing=true,autojacvec=false)) - - @test du0ReverseDiff ≈ du0 rtol=1e-2 - @test dpReverseDiff ≈ dp' rtol=1e-2 + function f(du, u, p, t, W) + du[1] = p[1] * u[1] * sin(W[1] - W[2]) + du[2] = p[2] * u[2] * cos(W[1] + W[2]) + return nothing + end + dt = 1e-4 + u0 = [1.00; 1.00] + tspan = (0.0, 5.0) + t = tspan[1]:0.1:tspan[2] + p = [2.0, -2.0] + prob = RODEProblem(f, u0, tspan, p) + + sol = solve(prob, RandomEM(), dt = dt, save_noise = true) + # check reversion with usage of Noise Grid + _sol = deepcopy(sol) + noise_reverse = NoiseGrid(reverse(_sol.t), reverse(_sol.W.W)) + prob_reverse = RODEProblem(f, _sol[end], reverse(tspan), p, noise = noise_reverse) + sol_reverse = solve(prob_reverse, RandomEM(), dt = dt) + @test sol.u≈reverse(sol_reverse.u) rtol=1e-3 + @show minimum(sol.u) + + # Test if Forward and ReverseMode AD agree. + Random.seed!(seed) + du0ReverseDiff, dpReverseDiff = Zygote.gradient((u0, p) -> sum(Array(solve(prob, + RandomEM(), + dt = dt, + u0 = u0, + p = p, + saveat = t, + sensealg = ReverseDiffAdjoint())) .^ + 2 / 2), u0, p) + Random.seed!(seed) + dForward = ForwardDiff.gradient((θ) -> sum(Array(solve(prob, RandomEM(), dt = dt, + u0 = θ[1:2], p = θ[3:4], + saveat = t)) .^ 2 / 2), [u0; p]) + + @info dForward + + @test du0ReverseDiff ≈ dForward[1:2] + @test dpReverseDiff ≈ dForward[3:4] + + # test gradients + Random.seed!(seed) + sol = solve(prob, RandomEM(), dt = dt, save_noise = true, saveat = t) + + ### + ## BacksolveAdjoint + ### + + # ReverseDiff + du0, dp = adjoint_sensitivities(sol, RandomEM(), t = Array(t), dg_discrete = dg!, + dt = dt, adaptive = false, + sensealg = BacksolveAdjoint()) + + @test du0ReverseDiff≈du0 rtol=1e-2 + @test dpReverseDiff≈dp' rtol=1e-2 + + @info du0, dp' + + # ReverseDiff with compiled tape + du0, dp = adjoint_sensitivities(sol, RandomEM(), t = Array(t), dg_discrete = dg!, + dt = dt, adaptive = false, + sensealg = BacksolveAdjoint(autojacvec = ReverseDiffVJP(true))) + + @test du0ReverseDiff≈du0 rtol=1e-2 + @test dpReverseDiff≈dp' rtol=1e-2 + + du0, dp = adjoint_sensitivities(sol, RandomEM(), t = Array(t), dg_discrete = dg!, + dt = dt, adaptive = false, + sensealg = BacksolveAdjoint(autojacvec = ZygoteVJP())) + + @test du0ReverseDiff≈du0 rtol=1e-2 + @test dpReverseDiff≈dp' rtol=1e-2 + + # Tracker + du0, dp = adjoint_sensitivities(sol, RandomEM(), t = Array(t), dg_discrete = dg!, + dt = dt, adaptive = false, + sensealg = BacksolveAdjoint(autojacvec = TrackerVJP())) + + @test du0ReverseDiff≈du0 rtol=1e-2 + @test dpReverseDiff≈dp' rtol=1e-2 + + # isautojacvec = false + du0, dp = adjoint_sensitivities(sol, RandomEM(), t = Array(t), dg_discrete = dg!, + dt = dt, adaptive = false, + sensealg = BacksolveAdjoint(autojacvec = false)) + + @test du0ReverseDiff≈du0 rtol=1e-2 + @test dpReverseDiff≈dp' rtol=1e-2 + + # isautojacvec = false and with jac and paramjac + function jac(J, u, p, t, W) + J[1, 1] = p[1] * sin(W[1] - W[2]) + J[2, 1] = zero(u[1]) + J[1, 2] = zero(u[1]) + J[2, 2] = p[2] * cos(W[1] + W[2]) + end + + function paramjac(J, u, p, t, W) + J[1, 1] = u[1] * sin(W[1] - W[2]) + J[2, 1] = zero(u[1]) + J[1, 2] = zero(u[1]) + J[2, 2] = u[2] * cos(W[1] + W[2]) + end + Random.seed!(seed) + faug = RODEFunction(f, jac = jac, paramjac = paramjac) + prob_aug = RODEProblem{true}(faug, u0, tspan, p) + sol = solve(prob_aug, RandomEM(), dt = dt, save_noise = true, saveat = t) + du0, dp = adjoint_sensitivities(sol, RandomEM(), t = Array(t), dg_discrete = dg!, + dt = dt, adaptive = false, + sensealg = BacksolveAdjoint(autojacvec = false)) + + @test du0ReverseDiff≈du0 rtol=1e-2 + @test dpReverseDiff≈dp' rtol=1e-2 + + ### + ## InterpolatingAdjoint + ### + + # test gradients with dense solution and no checkpointing + Random.seed!(seed) + sol = solve(prob, RandomEM(), dt = dt, save_noise = true, dense = true) + + # ReverseDiff + du0, dp = adjoint_sensitivities(sol, RandomEM(), t = Array(t), dg_discrete = dg!, + dt = dt, adaptive = false, + sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP())) + + @test du0ReverseDiff≈du0 rtol=1e-2 + @test dpReverseDiff≈dp' rtol=1e-2 + + # ReverseDiff with compiled tape + du0, dp = adjoint_sensitivities(sol, RandomEM(), t = Array(t), dg_discrete = dg!, + dt = dt, adaptive = false, + sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP(true))) + + @test du0ReverseDiff≈du0 rtol=1e-2 + @test dpReverseDiff≈dp' rtol=1e-2 + + # Zygote + du0, dp = adjoint_sensitivities(sol, RandomEM(), t = Array(t), dg_discrete = dg!, + dt = dt, adaptive = false, + sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP())) + + @test du0ReverseDiff≈du0 rtol=1e-2 + @test dpReverseDiff≈dp' rtol=1e-2 + + # Tracker + du0, dp = adjoint_sensitivities(sol, RandomEM(), t = Array(t), dg_discrete = dg!, + dt = dt, adaptive = false, + sensealg = InterpolatingAdjoint(autojacvec = TrackerVJP())) + + @test du0ReverseDiff≈du0 rtol=1e-2 + @test dpReverseDiff≈dp' rtol=1e-2 + + # isautojacvec = false + du0, dp = adjoint_sensitivities(sol, RandomEM(), t = Array(t), dg_discrete = dg!, + dt = dt, adaptive = false, + sensealg = InterpolatingAdjoint(autojacvec = false)) + + @test du0ReverseDiff≈du0 rtol=1e-2 + @test dpReverseDiff≈dp' rtol=1e-2 + + # isautojacvec = false and with jac and paramjac + Random.seed!(seed) + faug = RODEFunction(f, jac = jac, paramjac = paramjac) + prob_aug = RODEProblem{true}(faug, u0, tspan, p) + sol = solve(prob_aug, RandomEM(), dt = dt, save_noise = true, dense = true) + du0, dp = adjoint_sensitivities(sol, RandomEM(), t = Array(t), dg_discrete = dg!, + dt = dt, adaptive = false, + sensealg = InterpolatingAdjoint(autojacvec = false)) + + @test du0ReverseDiff≈du0 rtol=1e-2 + @test dpReverseDiff≈dp' rtol=1e-2 + + # test gradients with saveat solution and checkpointing + # need to simulate for dt beyond last tspan to avoid errors in NoiseGrid + Random.seed!(seed) + sol = solve(prob, RandomEM(), dt = dt, save_noise = true, dense = true) + Random.seed!(seed) + sol_long = solve(remake(prob, tspan = (tspan[1], tspan[2] + 10dt)), RandomEM(), dt = dt, + save_noise = true, dense = true) + + @test sol_long(t)≈sol(t) rtol=1e-12 + @test sol_long.W.W[1:(end - 10)]≈sol.W.W[1:end] rtol=1e-12 + + # test gradients with saveat solution and checkpointing + noise = NoiseGrid(sol_long.W.t, sol_long.W.W) + sol2 = solve(remake(prob, noise = noise, tspan = (tspan[1], tspan[2])), RandomEM(), + dt = dt, saveat = t) + + @test sol_long(t)≈sol2(t) rtol=1e-12 + @test sol_long.W.W≈sol2.W.W rtol=1e-12 + + # ReverseDiff + du0, dp = adjoint_sensitivities(sol2, RandomEM(), t = Array(t), dg_discrete = dg!, + dt = dt, adaptive = false, + sensealg = InterpolatingAdjoint(checkpointing = true, + autojacvec = ReverseDiffVJP())) + + @test du0ReverseDiff≈du0 rtol=1e-2 + @test dpReverseDiff≈dp' rtol=1e-2 + + # ReverseDiff with compiled tape + du0, dp = adjoint_sensitivities(sol2, RandomEM(), t = Array(t), dg_discrete = dg!, + dt = dt, adaptive = false, + sensealg = InterpolatingAdjoint(checkpointing = true, + autojacvec = ReverseDiffVJP(true))) + + @test du0ReverseDiff≈du0 rtol=1e-2 + @test dpReverseDiff≈dp' rtol=1e-2 + + # Zygote + du0, dp = adjoint_sensitivities(sol2, RandomEM(), t = Array(t), dg_discrete = dg!, + dt = dt, adaptive = false, + sensealg = InterpolatingAdjoint(checkpointing = true, + autojacvec = ZygoteVJP())) + + @test du0ReverseDiff≈du0 rtol=1e-2 + @test dpReverseDiff≈dp' rtol=1e-2 + + # Tracker + du0, dp = adjoint_sensitivities(sol2, RandomEM(), t = Array(t), dg_discrete = dg!, + dt = dt, adaptive = false, + sensealg = InterpolatingAdjoint(checkpointing = true, + autojacvec = TrackerVJP())) + + @test du0ReverseDiff≈du0 rtol=1e-2 + @test dpReverseDiff≈dp' rtol=1e-2 + + # isautojacvec = false + du0, dp = adjoint_sensitivities(sol2, RandomEM(), t = Array(t), dg_discrete = dg!, + dt = dt, adaptive = false, + sensealg = InterpolatingAdjoint(checkpointing = true, + autojacvec = false)) + + @test du0ReverseDiff≈du0 rtol=1e-2 + @test dpReverseDiff≈dp' rtol=1e-2 + + # isautojacvec = false and with jac and paramjac + Random.seed!(seed) + faug = RODEFunction(f, jac = jac, paramjac = paramjac) + prob_aug = RODEProblem{true}(faug, u0, tspan, p, noise = noise) + sol = solve(prob_aug, RandomEM(), dt = dt, save_noise = false, dense = false, + saveat = t) + du0, dp = adjoint_sensitivities(sol, RandomEM(), t = Array(t), dg_discrete = dg!, + dt = dt, adaptive = false, + sensealg = InterpolatingAdjoint(checkpointing = true, + autojacvec = false)) + + @test du0ReverseDiff≈du0 rtol=1e-2 + @test dpReverseDiff≈dp' rtol=1e-2 end - @testset "noise oop tests" begin - function f(u,p,t,W) - dx = p[1]*u[1]*sin(W[1] - W[2]) - dy = p[2]*u[2]*cos(W[1] + W[2]) - return [dx,dy] - end - dt = 1e-4 - u0 = [1.00;1.00] - tspan = (0.0,5.0) - t = tspan[1]:0.1:tspan[2] - p = [2.0,-2.0] - prob = RODEProblem{false}(f,u0,tspan,p) - - sol = solve(prob,RandomEM(),dt=dt, save_noise=true) - # check reversion with usage of Noise Grid - _sol = deepcopy(sol) - noise_reverse = NoiseGrid(reverse(_sol.t),reverse(_sol.W.W)) - prob_reverse = RODEProblem(f,_sol[end],reverse(tspan),p,noise=noise_reverse) - sol_reverse = solve(prob_reverse,RandomEM(),dt=dt) - @test sol.u ≈ reverse(sol_reverse.u) rtol=1e-3 - @show minimum(sol.u) - - # Test if Forward and ReverseMode AD agree. - Random.seed!(seed) - du0ReverseDiff,dpReverseDiff = Zygote.gradient((u0,p)->sum( - Array(solve(prob,RandomEM(),dt=dt,u0=u0,p=p,saveat=t,sensealg=ReverseDiffAdjoint())).^2/2) - ,u0,p) - Random.seed!(seed) - dForward = ForwardDiff.gradient((θ)->sum( - Array(solve(prob,RandomEM(),dt=dt,u0=θ[1:2],p=θ[3:4],saveat=t)).^2/2) - ,[u0;p]) - - @info dForward - - @test du0ReverseDiff ≈ dForward[1:2] - @test dpReverseDiff ≈ dForward[3:4] - - # test gradients - Random.seed!(seed) - sol = solve(prob,RandomEM(),dt=dt, save_noise=true, saveat=t) - - ### - ## BacksolveAdjoint - ### - - # ReverseDiff - du0, dp = adjoint_sensitivities(sol,RandomEM(),t=Array(t),dg_discrete=dg!, - dt=dt,adaptive=false,sensealg=BacksolveAdjoint(autojacvec=ReverseDiffVJP())) - - @test du0ReverseDiff ≈ du0 rtol=1e-2 - @test dpReverseDiff ≈ dp' rtol=1e-2 - - @info du0, dp' - - # ReverseDiff with compiled tape - du0, dp = adjoint_sensitivities(sol,RandomEM(),t=Array(t),dg_discrete=dg!, - dt=dt,adaptive=false,sensealg=BacksolveAdjoint(autojacvec=ReverseDiffVJP(true))) - - @test du0ReverseDiff ≈ du0 rtol=1e-2 - @test dpReverseDiff ≈ dp' rtol=1e-2 - - # Zygote - du0, dp = adjoint_sensitivities(sol,RandomEM(),t=Array(t),dg_discrete=dg!, - dt=dt,adaptive=false,sensealg=BacksolveAdjoint(autojacvec=ZygoteVJP())) - - @test du0ReverseDiff ≈ du0 rtol=1e-2 - @test dpReverseDiff ≈ dp' rtol=1e-2 - - # Tracker - du0, dp = adjoint_sensitivities(sol,RandomEM(),t=Array(t),dg_discrete=dg!, - dt=dt,adaptive=false,sensealg=BacksolveAdjoint(autojacvec=TrackerVJP())) - - @test du0ReverseDiff ≈ du0 rtol=1e-2 - @test dpReverseDiff ≈ dp' rtol=1e-2 - - # isautojacvec = false - du0, dp = adjoint_sensitivities(sol,RandomEM(),t=Array(t),dg_discrete=dg!, - dt=dt,adaptive=false,sensealg=BacksolveAdjoint(autojacvec=false)) - - @test du0ReverseDiff ≈ du0 rtol=1e-2 - @test dpReverseDiff ≈ dp' rtol=1e-2 - - # isautojacvec = false and with jac and paramjac - function jac(J,u,p,t,W) - J[1,1] = p[1]*sin(W[1] - W[2]) - J[2,1] = zero(u[1]) - J[1,2] = zero(u[1]) - J[2,2] = p[2]*cos(W[1] + W[2]) - end - - function paramjac(J,u,p,t,W) - J[1,1] = u[1]*sin(W[1] - W[2]) - J[2,1] = zero(u[1]) - J[1,2] = zero(u[1]) - J[2,2] = u[2]*cos(W[1] + W[2]) - end - Random.seed!(seed) - faug = RODEFunction(f,jac=jac,paramjac=paramjac) - prob_aug = RODEProblem{false}(faug,u0,tspan,p) - sol = solve(prob_aug,RandomEM(),dt=dt, save_noise=true, saveat=t) - du0, dp = adjoint_sensitivities(sol,RandomEM(),t=Array(t),dg_discrete=dg!, - dt=dt,adaptive=false,sensealg=BacksolveAdjoint(autojacvec=false)) - - @test du0ReverseDiff ≈ du0 rtol=1e-2 - @test dpReverseDiff ≈ dp' rtol=1e-2 - - ### - ## InterpolatingAdjoint - ### - - # test gradients with dense solution and no checkpointing - Random.seed!(seed) - sol = solve(prob,RandomEM(),dt=dt, save_noise=true, dense=true) - - # ReverseDiff - du0, dp = adjoint_sensitivities(sol,RandomEM(),t=Array(t),dg_discrete=dg!, - dt=dt,adaptive=false,sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP())) - - @test du0ReverseDiff ≈ du0 rtol=1e-2 - @test dpReverseDiff ≈ dp' rtol=1e-2 - - # ReverseDiff with compiled tape - du0, dp = adjoint_sensitivities(sol,RandomEM(),t=Array(t),dg_discrete=dg!, - dt=dt,adaptive=false,sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true))) - - @test du0ReverseDiff ≈ du0 rtol=1e-2 - @test dpReverseDiff ≈ dp' rtol=1e-2 - - # Zygote - du0, dp = adjoint_sensitivities(sol,RandomEM(),t=Array(t),dg_discrete=dg!, - dt=dt,adaptive=false,sensealg=InterpolatingAdjoint(autojacvec=ZygoteVJP())) - - @test du0ReverseDiff ≈ du0 rtol=1e-2 - @test dpReverseDiff ≈ dp' rtol=1e-2 - - # Tracker - du0, dp = adjoint_sensitivities(sol,RandomEM(),t=Array(t),dg_discrete=dg!, - dt=dt,adaptive=false,sensealg=InterpolatingAdjoint(autojacvec=TrackerVJP())) - - @test du0ReverseDiff ≈ du0 rtol=1e-2 - @test dpReverseDiff ≈ dp' rtol=1e-2 - - # isautojacvec = false - du0, dp = adjoint_sensitivities(sol,RandomEM(),t=Array(t),dg_discrete=dg!, - dt=dt,adaptive=false,sensealg=InterpolatingAdjoint(autojacvec=false)) - - @test du0ReverseDiff ≈ du0 rtol=1e-2 - @test dpReverseDiff ≈ dp' rtol=1e-2 - - # isautojacvec = false and with jac and paramjac - Random.seed!(seed) - faug = RODEFunction(f,jac=jac,paramjac=paramjac) - prob_aug = RODEProblem{false}(faug,u0,tspan,p) - sol = solve(prob_aug,RandomEM(),dt=dt, save_noise=true, dense=true) - du0, dp = adjoint_sensitivities(sol,RandomEM(),t=Array(t),dg_discrete=dg!, - dt=dt,adaptive=false,sensealg=InterpolatingAdjoint(autojacvec=false)) - - @test du0ReverseDiff ≈ du0 rtol=1e-2 - @test dpReverseDiff ≈ dp' rtol=1e-2 - - # test gradients with saveat solution and checkpointing - # need to simulate for dt beyond last tspan to avoid errors in NoiseGrid - Random.seed!(seed) - sol = solve(prob,RandomEM(),dt=dt, save_noise=true, dense=true) - Random.seed!(seed) - sol_long = solve(remake(prob, tspan=(tspan[1],tspan[2]+10dt)),RandomEM(),dt=dt, save_noise=true, dense=true) - - @test sol_long(t) ≈ sol(t) rtol=1e-12 - @test sol_long.W.W[1:end-10] ≈ sol.W.W[1:end] rtol=1e-12 - - # test gradients with saveat solution and checkpointing - noise = NoiseGrid(sol_long.W.t,sol_long.W.W) - sol2 = solve(remake(prob,noise=noise,tspan=(tspan[1],tspan[2])),RandomEM(),dt=dt, saveat=t) - - @test sol_long(t) ≈ sol2(t) rtol=1e-12 - @test sol_long.W.W ≈ sol2.W.W rtol=1e-12 - - # ReverseDiff - du0, dp = adjoint_sensitivities(sol2,RandomEM(),t=Array(t),dg_discrete=dg!, - dt=dt,adaptive=false,sensealg=InterpolatingAdjoint(checkpointing=true,autojacvec=ReverseDiffVJP())) - - @test du0ReverseDiff ≈ du0 rtol=1e-2 - @test dpReverseDiff ≈ dp' rtol=1e-2 - - # ReverseDiff with compiled tape - du0, dp = adjoint_sensitivities(sol2,RandomEM(),t=Array(t),dg_discrete=dg!, - dt=dt,adaptive=false,sensealg=InterpolatingAdjoint(checkpointing=true,autojacvec=ReverseDiffVJP(true))) - - @test du0ReverseDiff ≈ du0 rtol=1e-2 - @test dpReverseDiff ≈ dp' rtol=1e-2 - - # Zygote - du0, dp = adjoint_sensitivities(sol2,RandomEM(),t=Array(t),dg_discrete=dg!, - dt=dt,adaptive=false,sensealg=InterpolatingAdjoint(checkpointing=true,autojacvec=ZygoteVJP())) - - @test du0ReverseDiff ≈ du0 rtol=1e-2 - @test dpReverseDiff ≈ dp' rtol=1e-2 - - # Tracker - du0, dp = adjoint_sensitivities(sol2,RandomEM(),t=Array(t),dg_discrete=dg!, - dt=dt,adaptive=false,sensealg=InterpolatingAdjoint(checkpointing=true,autojacvec=TrackerVJP())) - - @test du0ReverseDiff ≈ du0 rtol=1e-2 - @test dpReverseDiff ≈ dp' rtol=1e-2 - - # isautojacvec = false - du0, dp = adjoint_sensitivities(sol2,RandomEM(),t=Array(t),dg_discrete=dg!, - dt=dt,adaptive=false,sensealg=InterpolatingAdjoint(checkpointing=true,autojacvec=false)) - - @test du0ReverseDiff ≈ du0 rtol=1e-2 - @test dpReverseDiff ≈ dp' rtol=1e-2 - - # isautojacvec = false and with jac and paramjac - Random.seed!(seed) - faug = RODEFunction(f,jac=jac,paramjac=paramjac) - prob_aug = RODEProblem{false}(faug,u0,tspan,p,noise=noise) - sol = solve(prob_aug,RandomEM(),dt=dt, save_noise=false, saveat=t, dense=false) - du0, dp = adjoint_sensitivities(sol,RandomEM(),t=Array(t),dg_discrete=dg!, - dt=dt,adaptive=false,sensealg=InterpolatingAdjoint(checkpointing=true,autojacvec=false)) - - @test du0ReverseDiff ≈ du0 rtol=1e-2 - @test dpReverseDiff ≈ dp' rtol=1e-2 + function f(u, p, t, W) + dx = p[1] * u[1] * sin(W[1] - W[2]) + dy = p[2] * u[2] * cos(W[1] + W[2]) + return [dx, dy] + end + dt = 1e-4 + u0 = [1.00; 1.00] + tspan = (0.0, 5.0) + t = tspan[1]:0.1:tspan[2] + p = [2.0, -2.0] + prob = RODEProblem{false}(f, u0, tspan, p) + + sol = solve(prob, RandomEM(), dt = dt, save_noise = true) + # check reversion with usage of Noise Grid + _sol = deepcopy(sol) + noise_reverse = NoiseGrid(reverse(_sol.t), reverse(_sol.W.W)) + prob_reverse = RODEProblem(f, _sol[end], reverse(tspan), p, noise = noise_reverse) + sol_reverse = solve(prob_reverse, RandomEM(), dt = dt) + @test sol.u≈reverse(sol_reverse.u) rtol=1e-3 + @show minimum(sol.u) + + # Test if Forward and ReverseMode AD agree. + Random.seed!(seed) + du0ReverseDiff, dpReverseDiff = Zygote.gradient((u0, p) -> sum(Array(solve(prob, + RandomEM(), + dt = dt, + u0 = u0, + p = p, + saveat = t, + sensealg = ReverseDiffAdjoint())) .^ + 2 / 2), u0, p) + Random.seed!(seed) + dForward = ForwardDiff.gradient((θ) -> sum(Array(solve(prob, RandomEM(), dt = dt, + u0 = θ[1:2], p = θ[3:4], + saveat = t)) .^ 2 / 2), [u0; p]) + + @info dForward + + @test du0ReverseDiff ≈ dForward[1:2] + @test dpReverseDiff ≈ dForward[3:4] + + # test gradients + Random.seed!(seed) + sol = solve(prob, RandomEM(), dt = dt, save_noise = true, saveat = t) + + ### + ## BacksolveAdjoint + ### + + # ReverseDiff + du0, dp = adjoint_sensitivities(sol, RandomEM(), t = Array(t), dg_discrete = dg!, + dt = dt, adaptive = false, + sensealg = BacksolveAdjoint(autojacvec = ReverseDiffVJP())) + + @test du0ReverseDiff≈du0 rtol=1e-2 + @test dpReverseDiff≈dp' rtol=1e-2 + + @info du0, dp' + + # ReverseDiff with compiled tape + du0, dp = adjoint_sensitivities(sol, RandomEM(), t = Array(t), dg_discrete = dg!, + dt = dt, adaptive = false, + sensealg = BacksolveAdjoint(autojacvec = ReverseDiffVJP(true))) + + @test du0ReverseDiff≈du0 rtol=1e-2 + @test dpReverseDiff≈dp' rtol=1e-2 + + # Zygote + du0, dp = adjoint_sensitivities(sol, RandomEM(), t = Array(t), dg_discrete = dg!, + dt = dt, adaptive = false, + sensealg = BacksolveAdjoint(autojacvec = ZygoteVJP())) + + @test du0ReverseDiff≈du0 rtol=1e-2 + @test dpReverseDiff≈dp' rtol=1e-2 + + # Tracker + du0, dp = adjoint_sensitivities(sol, RandomEM(), t = Array(t), dg_discrete = dg!, + dt = dt, adaptive = false, + sensealg = BacksolveAdjoint(autojacvec = TrackerVJP())) + + @test du0ReverseDiff≈du0 rtol=1e-2 + @test dpReverseDiff≈dp' rtol=1e-2 + + # isautojacvec = false + du0, dp = adjoint_sensitivities(sol, RandomEM(), t = Array(t), dg_discrete = dg!, + dt = dt, adaptive = false, + sensealg = BacksolveAdjoint(autojacvec = false)) + + @test du0ReverseDiff≈du0 rtol=1e-2 + @test dpReverseDiff≈dp' rtol=1e-2 + + # isautojacvec = false and with jac and paramjac + function jac(J, u, p, t, W) + J[1, 1] = p[1] * sin(W[1] - W[2]) + J[2, 1] = zero(u[1]) + J[1, 2] = zero(u[1]) + J[2, 2] = p[2] * cos(W[1] + W[2]) + end + + function paramjac(J, u, p, t, W) + J[1, 1] = u[1] * sin(W[1] - W[2]) + J[2, 1] = zero(u[1]) + J[1, 2] = zero(u[1]) + J[2, 2] = u[2] * cos(W[1] + W[2]) + end + Random.seed!(seed) + faug = RODEFunction(f, jac = jac, paramjac = paramjac) + prob_aug = RODEProblem{false}(faug, u0, tspan, p) + sol = solve(prob_aug, RandomEM(), dt = dt, save_noise = true, saveat = t) + du0, dp = adjoint_sensitivities(sol, RandomEM(), t = Array(t), dg_discrete = dg!, + dt = dt, adaptive = false, + sensealg = BacksolveAdjoint(autojacvec = false)) + + @test du0ReverseDiff≈du0 rtol=1e-2 + @test dpReverseDiff≈dp' rtol=1e-2 + + ### + ## InterpolatingAdjoint + ### + + # test gradients with dense solution and no checkpointing + Random.seed!(seed) + sol = solve(prob, RandomEM(), dt = dt, save_noise = true, dense = true) + + # ReverseDiff + du0, dp = adjoint_sensitivities(sol, RandomEM(), t = Array(t), dg_discrete = dg!, + dt = dt, adaptive = false, + sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP())) + + @test du0ReverseDiff≈du0 rtol=1e-2 + @test dpReverseDiff≈dp' rtol=1e-2 + + # ReverseDiff with compiled tape + du0, dp = adjoint_sensitivities(sol, RandomEM(), t = Array(t), dg_discrete = dg!, + dt = dt, adaptive = false, + sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP(true))) + + @test du0ReverseDiff≈du0 rtol=1e-2 + @test dpReverseDiff≈dp' rtol=1e-2 + + # Zygote + du0, dp = adjoint_sensitivities(sol, RandomEM(), t = Array(t), dg_discrete = dg!, + dt = dt, adaptive = false, + sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP())) + + @test du0ReverseDiff≈du0 rtol=1e-2 + @test dpReverseDiff≈dp' rtol=1e-2 + + # Tracker + du0, dp = adjoint_sensitivities(sol, RandomEM(), t = Array(t), dg_discrete = dg!, + dt = dt, adaptive = false, + sensealg = InterpolatingAdjoint(autojacvec = TrackerVJP())) + + @test du0ReverseDiff≈du0 rtol=1e-2 + @test dpReverseDiff≈dp' rtol=1e-2 + + # isautojacvec = false + du0, dp = adjoint_sensitivities(sol, RandomEM(), t = Array(t), dg_discrete = dg!, + dt = dt, adaptive = false, + sensealg = InterpolatingAdjoint(autojacvec = false)) + + @test du0ReverseDiff≈du0 rtol=1e-2 + @test dpReverseDiff≈dp' rtol=1e-2 + + # isautojacvec = false and with jac and paramjac + Random.seed!(seed) + faug = RODEFunction(f, jac = jac, paramjac = paramjac) + prob_aug = RODEProblem{false}(faug, u0, tspan, p) + sol = solve(prob_aug, RandomEM(), dt = dt, save_noise = true, dense = true) + du0, dp = adjoint_sensitivities(sol, RandomEM(), t = Array(t), dg_discrete = dg!, + dt = dt, adaptive = false, + sensealg = InterpolatingAdjoint(autojacvec = false)) + + @test du0ReverseDiff≈du0 rtol=1e-2 + @test dpReverseDiff≈dp' rtol=1e-2 + + # test gradients with saveat solution and checkpointing + # need to simulate for dt beyond last tspan to avoid errors in NoiseGrid + Random.seed!(seed) + sol = solve(prob, RandomEM(), dt = dt, save_noise = true, dense = true) + Random.seed!(seed) + sol_long = solve(remake(prob, tspan = (tspan[1], tspan[2] + 10dt)), RandomEM(), dt = dt, + save_noise = true, dense = true) + + @test sol_long(t)≈sol(t) rtol=1e-12 + @test sol_long.W.W[1:(end - 10)]≈sol.W.W[1:end] rtol=1e-12 + + # test gradients with saveat solution and checkpointing + noise = NoiseGrid(sol_long.W.t, sol_long.W.W) + sol2 = solve(remake(prob, noise = noise, tspan = (tspan[1], tspan[2])), RandomEM(), + dt = dt, saveat = t) + + @test sol_long(t)≈sol2(t) rtol=1e-12 + @test sol_long.W.W≈sol2.W.W rtol=1e-12 + + # ReverseDiff + du0, dp = adjoint_sensitivities(sol2, RandomEM(), t = Array(t), dg_discrete = dg!, + dt = dt, adaptive = false, + sensealg = InterpolatingAdjoint(checkpointing = true, + autojacvec = ReverseDiffVJP())) + + @test du0ReverseDiff≈du0 rtol=1e-2 + @test dpReverseDiff≈dp' rtol=1e-2 + + # ReverseDiff with compiled tape + du0, dp = adjoint_sensitivities(sol2, RandomEM(), t = Array(t), dg_discrete = dg!, + dt = dt, adaptive = false, + sensealg = InterpolatingAdjoint(checkpointing = true, + autojacvec = ReverseDiffVJP(true))) + + @test du0ReverseDiff≈du0 rtol=1e-2 + @test dpReverseDiff≈dp' rtol=1e-2 + + # Zygote + du0, dp = adjoint_sensitivities(sol2, RandomEM(), t = Array(t), dg_discrete = dg!, + dt = dt, adaptive = false, + sensealg = InterpolatingAdjoint(checkpointing = true, + autojacvec = ZygoteVJP())) + + @test du0ReverseDiff≈du0 rtol=1e-2 + @test dpReverseDiff≈dp' rtol=1e-2 + + # Tracker + du0, dp = adjoint_sensitivities(sol2, RandomEM(), t = Array(t), dg_discrete = dg!, + dt = dt, adaptive = false, + sensealg = InterpolatingAdjoint(checkpointing = true, + autojacvec = TrackerVJP())) + + @test du0ReverseDiff≈du0 rtol=1e-2 + @test dpReverseDiff≈dp' rtol=1e-2 + + # isautojacvec = false + du0, dp = adjoint_sensitivities(sol2, RandomEM(), t = Array(t), dg_discrete = dg!, + dt = dt, adaptive = false, + sensealg = InterpolatingAdjoint(checkpointing = true, + autojacvec = false)) + + @test du0ReverseDiff≈du0 rtol=1e-2 + @test dpReverseDiff≈dp' rtol=1e-2 + + # isautojacvec = false and with jac and paramjac + Random.seed!(seed) + faug = RODEFunction(f, jac = jac, paramjac = paramjac) + prob_aug = RODEProblem{false}(faug, u0, tspan, p, noise = noise) + sol = solve(prob_aug, RandomEM(), dt = dt, save_noise = false, saveat = t, + dense = false) + du0, dp = adjoint_sensitivities(sol, RandomEM(), t = Array(t), dg_discrete = dg!, + dt = dt, adaptive = false, + sensealg = InterpolatingAdjoint(checkpointing = true, + autojacvec = false)) + + @test du0ReverseDiff≈du0 rtol=1e-2 + @test dpReverseDiff≈dp' rtol=1e-2 end diff --git a/test/runtests.jl b/test/runtests.jl index bceb33a78..d0cf95003 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,108 +1,107 @@ -using SciMLSensitivity, SafeTestsets -using Test, Pkg - -const GROUP = get(ENV, "GROUP", "All") - -function activate_gpu_env() - Pkg.activate("gpu") - Pkg.develop(PackageSpec(path=dirname(@__DIR__))) - Pkg.instantiate() -end - -@time begin - -if GROUP == "All" || GROUP == "Core1" || GROUP == "Downstream" - @time @safetestset "Forward Sensitivity" begin include("forward.jl") end - @time @safetestset "Sparse Adjoint Sensitivity" begin include("sparse_adjoint.jl") end - @time @safetestset "Second Order Sensitivity" begin include("second_order.jl") end - @time @safetestset "Concrete Solve Derivatives" begin include("concrete_solve_derivatives.jl") end - @time @safetestset "Branching Derivatives" begin include("branching_derivatives.jl") end - @time @safetestset "Derivative Shapes" begin include("derivative_shapes.jl") end - @time @safetestset "save_idxs" begin include("save_idxs.jl") end - @time @safetestset "ArrayPartitions" begin include("array_partitions.jl") end - @time @safetestset "Complex Adjoints" begin include("complex_adjoints.jl") end - @time @safetestset "Forward Remake" begin include("forward_remake.jl") end - @time @safetestset "Prob Kwargs" begin include("prob_kwargs.jl") end - @time @safetestset "DiscreteProblem Adjoints" begin include("discrete.jl") end - @time @safetestset "Time Type Mixing Adjoints" begin include("time_type_mixing.jl") end -end - -if GROUP == "All" || GROUP == "Core2" - @time @safetestset "hasbranching" begin include("hasbranching.jl") end - @time @safetestset "Literal Adjoint" begin include("literal_adjoint.jl") end - @time @safetestset "ForwardDiff Chunking Adjoints" begin include("forward_chunking.jl") end - @time @safetestset "Stiff Adjoints" begin include("stiff_adjoints.jl") end - @time @safetestset "Autodiff Events" begin include("autodiff_events.jl") end - @time @safetestset "Null Parameters" begin include("null_parameters.jl") end - @time @safetestset "Forward Mode Prob Kwargs" begin include("forward_prob_kwargs.jl") end - @time @safetestset "Steady State Adjoint" begin include("steady_state.jl") end - @time @safetestset "Concrete Solve Derivatives of Second Order ODEs" begin include("second_order_odes.jl") end - @time @safetestset "Parameter Compatibility Errors" begin include("parameter_compatibility_errors.jl") end -end - -if GROUP == "All" || GROUP == "Core3" || GROUP == "Downstream" - @time @safetestset "Adjoint Sensitivity" begin include("adjoint.jl") end - @time @safetestset "Continuous and discrete costs" begin include("mixed_costs.jl") end -end - -if GROUP == "All" || GROUP == "Core4" - @time @safetestset "Ensemble Tests" begin include("ensembles.jl") end - @time @safetestset "GDP Regression Tests" begin include("gdp_regression_test.jl") end - @time @safetestset "Layers Tests" begin include("layers.jl") end - @time @safetestset "Layers SDE" begin include("layers_sde.jl") end - @time @safetestset "Layers DDE" begin include("layers_dde.jl") end - @time @safetestset "SDE - Neural" begin include("sde_neural.jl") end - - # No `@safetestset` since it requires running in Main - @time @testset "Distributed" begin include("distributed.jl") end -end - -if GROUP == "All" || GROUP == "Core5" - @time @safetestset "Partial Neural Tests" begin include("partial_neural.jl") end - @time @safetestset "Size Handling in Adjoint Tests" begin include("size_handling_adjoint.jl") end - @time @safetestset "Callback - ReverseDiff" begin include("callback_reversediff.jl") end - @time @safetestset "Alternative AD Frontend" begin include("alternative_ad_frontend.jl") end - @time @safetestset "Hybrid DE" begin include("hybrid_de.jl") end - @time @safetestset "HybridNODE" begin include("HybridNODE.jl") end - @time @safetestset "ForwardDiff Sparsity Components" begin include("forwarddiffsensitivity_sparsity_components.jl") end - @time @safetestset "Complex No u" begin include("complex_no_u.jl") end -end - -if GROUP == "All" || GROUP == "SDE1" - @time @safetestset "SDE Adjoint" begin include("sde_stratonovich.jl") end - @time @safetestset "SDE Scalar Noise" begin include("sde_scalar_stratonovich.jl") end - @time @safetestset "SDE Checkpointing" begin include("sde_checkpointing.jl") end -end - -if GROUP == "All" || GROUP == "SDE2" - @time @safetestset "SDE Non-Diagonal Noise" begin include("sde_nondiag_stratonovich.jl") end -end - -if GROUP == "All" || GROUP == "SDE3" - @time @safetestset "RODE Tests" begin include("rode.jl") end - @time @safetestset "SDE Ito Conversion Tests" begin include("sde_transformation_test.jl") end - @time @safetestset "SDE Ito Scalar Noise" begin include("sde_scalar_ito.jl") end -end - -if GROUP == "Callbacks1" - @time @safetestset "Discrete Callbacks with ForwardDiffSensitivity" begin include("callbacks/forward_sensitivity_callback.jl") end - @time @safetestset "Discrete Callbacks with Adjoints" begin include("callbacks/discrete_callbacks.jl") end - @time @safetestset "SDE Callbacks" begin include("callbacks/SDE_callbacks.jl") end -end - -if GROUP == "Callbacks2" - @time @safetestset "Continuous vs. discrete Callbacks" begin include("callbacks/continuous_vs_discrete.jl") end - @time @safetestset "Continuous Callbacks with Adjoints" begin include("callbacks/continuous_callbacks.jl") end - @time @safetestset "VectorContinuousCallbacks with Adjoints" begin include("callbacks/vector_continuous_callbacks.jl") end -end - -if GROUP == "Shadowing" - @time @safetestset "Shadowing Tests" begin include("shadowing.jl") end -end - -if GROUP == "GPU" - activate_gpu_env() - @time @safetestset "Standard DiffEqFlux GPU" begin include("gpu/diffeqflux_standard_gpu.jl") end - @time @safetestset "Mixed GPU/CPU" begin include("gpu/mixed_gpu_cpu_adjoint.jl") end -end -end +using SciMLSensitivity, SafeTestsets +using Test, Pkg + +const GROUP = get(ENV, "GROUP", "All") + +function activate_gpu_env() + Pkg.activate("gpu") + Pkg.develop(PackageSpec(path = dirname(@__DIR__))) + Pkg.instantiate() +end + +@time begin + if GROUP == "All" || GROUP == "Core1" || GROUP == "Downstream" + @time @safetestset "Forward Sensitivity" begin include("forward.jl") end + @time @safetestset "Sparse Adjoint Sensitivity" begin include("sparse_adjoint.jl") end + @time @safetestset "Second Order Sensitivity" begin include("second_order.jl") end + @time @safetestset "Concrete Solve Derivatives" begin include("concrete_solve_derivatives.jl") end + @time @safetestset "Branching Derivatives" begin include("branching_derivatives.jl") end + @time @safetestset "Derivative Shapes" begin include("derivative_shapes.jl") end + @time @safetestset "save_idxs" begin include("save_idxs.jl") end + @time @safetestset "ArrayPartitions" begin include("array_partitions.jl") end + @time @safetestset "Complex Adjoints" begin include("complex_adjoints.jl") end + @time @safetestset "Forward Remake" begin include("forward_remake.jl") end + @time @safetestset "Prob Kwargs" begin include("prob_kwargs.jl") end + @time @safetestset "DiscreteProblem Adjoints" begin include("discrete.jl") end + @time @safetestset "Time Type Mixing Adjoints" begin include("time_type_mixing.jl") end + end + + if GROUP == "All" || GROUP == "Core2" + @time @safetestset "hasbranching" begin include("hasbranching.jl") end + @time @safetestset "Literal Adjoint" begin include("literal_adjoint.jl") end + @time @safetestset "ForwardDiff Chunking Adjoints" begin include("forward_chunking.jl") end + @time @safetestset "Stiff Adjoints" begin include("stiff_adjoints.jl") end + @time @safetestset "Autodiff Events" begin include("autodiff_events.jl") end + @time @safetestset "Null Parameters" begin include("null_parameters.jl") end + @time @safetestset "Forward Mode Prob Kwargs" begin include("forward_prob_kwargs.jl") end + @time @safetestset "Steady State Adjoint" begin include("steady_state.jl") end + @time @safetestset "Concrete Solve Derivatives of Second Order ODEs" begin include("second_order_odes.jl") end + @time @safetestset "Parameter Compatibility Errors" begin include("parameter_compatibility_errors.jl") end + end + + if GROUP == "All" || GROUP == "Core3" || GROUP == "Downstream" + @time @safetestset "Adjoint Sensitivity" begin include("adjoint.jl") end + @time @safetestset "Continuous and discrete costs" begin include("mixed_costs.jl") end + end + + if GROUP == "All" || GROUP == "Core4" + @time @safetestset "Ensemble Tests" begin include("ensembles.jl") end + @time @safetestset "GDP Regression Tests" begin include("gdp_regression_test.jl") end + @time @safetestset "Layers Tests" begin include("layers.jl") end + @time @safetestset "Layers SDE" begin include("layers_sde.jl") end + @time @safetestset "Layers DDE" begin include("layers_dde.jl") end + @time @safetestset "SDE - Neural" begin include("sde_neural.jl") end + + # No `@safetestset` since it requires running in Main + @time @testset "Distributed" begin include("distributed.jl") end + end + + if GROUP == "All" || GROUP == "Core5" + @time @safetestset "Partial Neural Tests" begin include("partial_neural.jl") end + @time @safetestset "Size Handling in Adjoint Tests" begin include("size_handling_adjoint.jl") end + @time @safetestset "Callback - ReverseDiff" begin include("callback_reversediff.jl") end + @time @safetestset "Alternative AD Frontend" begin include("alternative_ad_frontend.jl") end + @time @safetestset "Hybrid DE" begin include("hybrid_de.jl") end + @time @safetestset "HybridNODE" begin include("HybridNODE.jl") end + @time @safetestset "ForwardDiff Sparsity Components" begin include("forwarddiffsensitivity_sparsity_components.jl") end + @time @safetestset "Complex No u" begin include("complex_no_u.jl") end + end + + if GROUP == "All" || GROUP == "SDE1" + @time @safetestset "SDE Adjoint" begin include("sde_stratonovich.jl") end + @time @safetestset "SDE Scalar Noise" begin include("sde_scalar_stratonovich.jl") end + @time @safetestset "SDE Checkpointing" begin include("sde_checkpointing.jl") end + end + + if GROUP == "All" || GROUP == "SDE2" + @time @safetestset "SDE Non-Diagonal Noise" begin include("sde_nondiag_stratonovich.jl") end + end + + if GROUP == "All" || GROUP == "SDE3" + @time @safetestset "RODE Tests" begin include("rode.jl") end + @time @safetestset "SDE Ito Conversion Tests" begin include("sde_transformation_test.jl") end + @time @safetestset "SDE Ito Scalar Noise" begin include("sde_scalar_ito.jl") end + end + + if GROUP == "Callbacks1" + @time @safetestset "Discrete Callbacks with ForwardDiffSensitivity" begin include("callbacks/forward_sensitivity_callback.jl") end + @time @safetestset "Discrete Callbacks with Adjoints" begin include("callbacks/discrete_callbacks.jl") end + @time @safetestset "SDE Callbacks" begin include("callbacks/SDE_callbacks.jl") end + end + + if GROUP == "Callbacks2" + @time @safetestset "Continuous vs. discrete Callbacks" begin include("callbacks/continuous_vs_discrete.jl") end + @time @safetestset "Continuous Callbacks with Adjoints" begin include("callbacks/continuous_callbacks.jl") end + @time @safetestset "VectorContinuousCallbacks with Adjoints" begin include("callbacks/vector_continuous_callbacks.jl") end + end + + if GROUP == "Shadowing" + @time @safetestset "Shadowing Tests" begin include("shadowing.jl") end + end + + if GROUP == "GPU" + activate_gpu_env() + @time @safetestset "Standard DiffEqFlux GPU" begin include("gpu/diffeqflux_standard_gpu.jl") end + @time @safetestset "Mixed GPU/CPU" begin include("gpu/mixed_gpu_cpu_adjoint.jl") end + end +end diff --git a/test/save_idxs.jl b/test/save_idxs.jl index 72fe70bc6..ceded8d53 100644 --- a/test/save_idxs.jl +++ b/test/save_idxs.jl @@ -1,31 +1,32 @@ -using OrdinaryDiffEq, SciMLSensitivity, Zygote, ForwardDiff, Test - -function lotka_volterra!(du, u, p, t) - x, y = u - α, β, δ, γ = p - du[1] = dx = α*x - β*x*y - du[2] = dy = -δ*y + γ*x*y -end - -# Initial condition -u0 = [1.0, 1.0] - -# Simulation interval and intermediary points -tspan = (0.0, 10.0) -tsteps = 0.0:0.1:10.0 - -# LV equation parameter. p = [α, β, δ, γ] -p = [1.5, 1.0, 3.0, 1.0] - -# Setup the ODE problem, then solve -prob = ODEProblem(lotka_volterra!, u0, tspan, p) - -function loss(p) - sol = solve(prob, Tsit5(), p=p, save_idxs=[2], saveat = tsteps, abstol=1e-14, reltol=1e-14) - loss = sum(abs2, sol.-1) - return loss -end - -grad1 = Zygote.gradient(loss,p)[1] -grad2 = ForwardDiff.gradient(loss,p) -@test grad1 ≈ grad2 +using OrdinaryDiffEq, SciMLSensitivity, Zygote, ForwardDiff, Test + +function lotka_volterra!(du, u, p, t) + x, y = u + α, β, δ, γ = p + du[1] = dx = α * x - β * x * y + du[2] = dy = -δ * y + γ * x * y +end + +# Initial condition +u0 = [1.0, 1.0] + +# Simulation interval and intermediary points +tspan = (0.0, 10.0) +tsteps = 0.0:0.1:10.0 + +# LV equation parameter. p = [α, β, δ, γ] +p = [1.5, 1.0, 3.0, 1.0] + +# Setup the ODE problem, then solve +prob = ODEProblem(lotka_volterra!, u0, tspan, p) + +function loss(p) + sol = solve(prob, Tsit5(), p = p, save_idxs = [2], saveat = tsteps, abstol = 1e-14, + reltol = 1e-14) + loss = sum(abs2, sol .- 1) + return loss +end + +grad1 = Zygote.gradient(loss, p)[1] +grad2 = ForwardDiff.gradient(loss, p) +@test grad1 ≈ grad2 diff --git a/test/sde_checkpointing.jl b/test/sde_checkpointing.jl index ef5c7ca15..86ce1fe94 100644 --- a/test/sde_checkpointing.jl +++ b/test/sde_checkpointing.jl @@ -15,79 +15,80 @@ trange = (tstart, tend) t = tstart:dt:tend tarray = collect(t) -function g(u,p,t) - sum(u.^2.0/2.0) +function g(u, p, t) + sum(u .^ 2.0 / 2.0) end -function dg!(out,u,p,t,i) - (out.=u) +function dg!(out, u, p, t, i) + (out .= u) end -p2 = [1.01,0.87] +p2 = [1.01, 0.87] +f_oop_linear(u, p, t) = p[1] * u +σ_oop_linear(u, p, t) = p[2] * u - -f_oop_linear(u,p,t) = p[1]*u -σ_oop_linear(u,p,t) = p[2]*u - -dt1 = tend/1e3 +dt1 = tend / 1e3 Random.seed!(seed) -prob_oop = SDEProblem(f_oop_linear,σ_oop_linear,u₀,trange,p2) -sol_oop = solve(prob_oop,EulerHeun(),dt=dt1,adaptive=false,save_noise=true) +prob_oop = SDEProblem(f_oop_linear, σ_oop_linear, u₀, trange, p2) +sol_oop = solve(prob_oop, EulerHeun(), dt = dt1, adaptive = false, save_noise = true) @show length(sol_oop) -res_u0, res_p = adjoint_sensitivities(sol_oop,EulerHeun(),t=tarray,dg_discrete=dg!, - dt=dt1,adaptive=false,sensealg=InterpolatingAdjoint(autojacvec=ZygoteVJP())) - +res_u0, res_p = adjoint_sensitivities(sol_oop, EulerHeun(), t = tarray, dg_discrete = dg!, + dt = dt1, adaptive = false, + sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP())) -res_u0a, res_pa = adjoint_sensitivities(sol_oop,EulerHeun(),t=tarray,dg_discrete=dg!,dt=dt1, - adaptive=false,sensealg=InterpolatingAdjoint(autojacvec=ZygoteVJP()), - checkpoints=sol_oop.t[1:2:end]) +res_u0a, res_pa = adjoint_sensitivities(sol_oop, EulerHeun(), t = tarray, dg_discrete = dg!, + dt = dt1, + adaptive = false, + sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP()), + checkpoints = sol_oop.t[1:2:end]) @test isapprox(res_u0, res_u0a, rtol = 1e-5) @test isapprox(res_p, res_pa, rtol = 1e-2) -res_u0a, res_pa = adjoint_sensitivities(sol_oop,EulerHeun(),t=tarray,dg_discrete=dg!, - dt=dt1,adaptive=false,sensealg=InterpolatingAdjoint(autojacvec=ZygoteVJP()), - checkpoints=sol_oop.t[1:10:end]) +res_u0a, res_pa = adjoint_sensitivities(sol_oop, EulerHeun(), t = tarray, dg_discrete = dg!, + dt = dt1, adaptive = false, + sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP()), + checkpoints = sol_oop.t[1:10:end]) @test isapprox(res_u0, res_u0a, rtol = 1e-5) @test isapprox(res_p, res_pa, rtol = 1e-1) - - -dt1 = tend/1e4 +dt1 = tend / 1e4 Random.seed!(seed) -prob_oop = SDEProblem(f_oop_linear,σ_oop_linear,u₀,trange,p2) -sol_oop = solve(prob_oop,EulerHeun(),dt=dt1,adaptive=false,save_noise=true) +prob_oop = SDEProblem(f_oop_linear, σ_oop_linear, u₀, trange, p2) +sol_oop = solve(prob_oop, EulerHeun(), dt = dt1, adaptive = false, save_noise = true) @show length(sol_oop) -res_u0, res_p = adjoint_sensitivities(sol_oop,EulerHeun(),t=tarray,dg_discrete=dg!, - dt=dt1,adaptive=false, - sensealg = InterpolatingAdjoint(autojacvec=ZygoteVJP())) - +res_u0, res_p = adjoint_sensitivities(sol_oop, EulerHeun(), t = tarray, dg_discrete = dg!, + dt = dt1, adaptive = false, + sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP())) -res_u0a, res_pa = adjoint_sensitivities(sol_oop,EulerHeun(),t=tarray,dg_discrete=dg!, - dt=dt1,adaptive=false,sensealg=InterpolatingAdjoint(autojacvec=ZygoteVJP()), - checkpoints=sol_oop.t[1:2:end]) +res_u0a, res_pa = adjoint_sensitivities(sol_oop, EulerHeun(), t = tarray, dg_discrete = dg!, + dt = dt1, adaptive = false, + sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP()), + checkpoints = sol_oop.t[1:2:end]) @test isapprox(res_u0, res_u0a, rtol = 1e-6) @test isapprox(res_p, res_pa, rtol = 1e-3) -res_u0a, res_pa = adjoint_sensitivities(sol_oop,EulerHeun(),t=tarray,dg_discrete=dg!, - dt=dt1,adaptive=false,sensealg=InterpolatingAdjoint(autojacvec=ZygoteVJP()), - checkpoints=sol_oop.t[1:10:end]) +res_u0a, res_pa = adjoint_sensitivities(sol_oop, EulerHeun(), t = tarray, dg_discrete = dg!, + dt = dt1, adaptive = false, + sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP()), + checkpoints = sol_oop.t[1:10:end]) @test isapprox(res_u0, res_u0a, rtol = 1e-6) @test isapprox(res_p, res_pa, rtol = 1e-2) -res_u0a, res_pa = adjoint_sensitivities(sol_oop,EulerHeun(),t=tarray,dg_discrete=dg!, - dt=dt1,adaptive=false,sensealg=InterpolatingAdjoint(autojacvec=ZygoteVJP()), - checkpoints=sol_oop.t[1:500:end]) +res_u0a, res_pa = adjoint_sensitivities(sol_oop, EulerHeun(), t = tarray, dg_discrete = dg!, + dt = dt1, adaptive = false, + sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP()), + checkpoints = sol_oop.t[1:500:end]) @test isapprox(res_u0, res_u0a, rtol = 1e-3) @test isapprox(res_p, res_pa, rtol = 1e-2) diff --git a/test/sde_neural.jl b/test/sde_neural.jl index 5b1b61519..3ceb93982 100644 --- a/test/sde_neural.jl +++ b/test/sde_neural.jl @@ -16,9 +16,8 @@ Random.seed!(238248735) du[1] = e * 0.5 * (5μ - u[1]) # nutrient input time series du[2] = e * 0.05 * (10μ - u[2]) # grazer density time series du[3] = 0.2 * exp(u[1]) - 0.05 * u[3] - r * u[3] / (h + u[3]) * u[4] # nutrient concentration - du[4] = - r * u[3] / (h + u[3]) * u[4] - 0.1 * u[4] - - 0.02 * u[4]^z / (ph^z + u[4]^z) * exp(u[2] / 2.0) + i #Algae density + du[4] = r * u[3] / (h + u[3]) * u[4] - 0.1 * u[4] - + 0.02 * u[4]^z / (ph^z + u[4]^z) * exp(u[2] / 2.0) + i #Algae density end function noise!(du, u, p, t) @@ -37,21 +36,19 @@ Random.seed!(238248735) prob = SDEProblem(sys!, noise!, u0, tspan, p_) ensembleprob = EnsembleProblem(prob) - solution = solve( - ensembleprob, - SOSRI(), - EnsembleThreads(); - trajectories = 1000, - abstol = 1e-5, - reltol = 1e-5, - maxiters = 1e8, - saveat = tsteps, - ) + solution = solve(ensembleprob, + SOSRI(), + EnsembleThreads(); + trajectories = 1000, + abstol = 1e-5, + reltol = 1e-5, + maxiters = 1e8, + saveat = tsteps) (truemean, truevar) = Array.(timeseries_steps_meanvar(solution)) ann = Chain(Dense(4, 32, tanh), Dense(32, 32, tanh), Dense(32, 2)) - α,re = Flux.destructure(ann) + α, re = Flux.destructure(ann) α = Float64.(α) function dudt_(du, u, p, t) @@ -72,9 +69,9 @@ Random.seed!(238248735) MM = re(p)(u) [e * 0.5 * (5μ - u[1]), # nutrient input time series - e * 0.05 * (10μ - u[2]), # grazer density time series - 0.2 * exp(u[1]) - 0.05 * u[3] - MM[1], # nutrient concentration - MM[2] - 0.1 * u[4] - 0.02 * u[4]^z / (ph^z + u[4]^z) * exp(u[2] / 2.0) + i] #Algae density + e * 0.05 * (10μ - u[2]), # grazer density time series + 0.2 * exp(u[1]) - 0.05 * u[3] - MM[1], # nutrient concentration + MM[2] - 0.1 * u[4] - 0.02 * u[4]^z / (ph^z + u[4]^z) * exp(u[2] / 2.0) + i] #Algae density end function noise_(du, u, p, t) @@ -87,9 +84,9 @@ Random.seed!(238248735) function noise_op(u, p, t) [p_[end], - p_[end], - 0.0, - 0.0] + p_[end], + 0.0, + 0.0] end prob_nn = SDEProblem(dudt_, noise_, u0, tspan, p = nothing) @@ -98,50 +95,49 @@ Random.seed!(238248735) function loss(θ) tmp_prob = remake(prob_nn, p = θ) ensembleprob = EnsembleProblem(tmp_prob) - tmp_sol = Array(solve( - ensembleprob, - EM(); - dt = tsteps.step, - trajectories = 100, - sensealg = ReverseDiffAdjoint(), - )) - tmp_mean = mean(tmp_sol,dims=3)[:,:] - tmp_var = var(tmp_sol,dims=3)[:,:] + tmp_sol = Array(solve(ensembleprob, + EM(); + dt = tsteps.step, + trajectories = 100, + sensealg = ReverseDiffAdjoint())) + tmp_mean = mean(tmp_sol, dims = 3)[:, :] + tmp_var = var(tmp_sol, dims = 3)[:, :] sum(abs2, truemean - tmp_mean) + 0.1 * sum(abs2, truevar - tmp_var), tmp_mean end function loss_op(θ) tmp_prob = remake(prob_nn_op, p = θ) ensembleprob = EnsembleProblem(tmp_prob) - tmp_sol = Array(solve( - ensembleprob, - EM(); - dt = tsteps.step, - trajectories = 100, - sensealg = ReverseDiffAdjoint(), - )) - tmp_mean = mean(tmp_sol,dims=3)[:,:] - tmp_var = var(tmp_sol,dims=3)[:,:] + tmp_sol = Array(solve(ensembleprob, + EM(); + dt = tsteps.step, + trajectories = 100, + sensealg = ReverseDiffAdjoint())) + tmp_mean = mean(tmp_sol, dims = 3)[:, :] + tmp_var = var(tmp_sol, dims = 3)[:, :] sum(abs2, truemean - tmp_mean) + 0.1 * sum(abs2, truevar - tmp_var), tmp_mean end losses = [] - callback(θ, l, pred) = begin - push!(losses, l) - if length(losses)%50 == 0 - println("Current loss after $(length(losses)) iterations: $(losses[end])") + function callback(θ, l, pred) + begin + push!(losses, l) + if length(losses) % 50 == 0 + println("Current loss after $(length(losses)) iterations: $(losses[end])") + end + false end - false end println("Test mutating form") - optf = Optimization.OptimizationFunction((x,p) -> loss(x), Optimization.AutoZygote()) + optf = Optimization.OptimizationFunction((x, p) -> loss(x), Optimization.AutoZygote()) optprob = Optimization.OptimizationProblem(optf, α) res1 = Optimization.solve(optprob, ADAM(0.001), callback = callback, maxiters = 200) println("Test non-mutating form") - optf = Optimization.OptimizationFunction((x,p) -> loss_op(x), Optimization.AutoZygote()) + optf = Optimization.OptimizationFunction((x, p) -> loss_op(x), + Optimization.AutoZygote()) optprob = Optimization.OptimizationProblem(optf, α) res2 = Optimization.solve(optprob, ADAM(0.001), callback = callback, maxiters = 200) end @@ -152,71 +148,76 @@ end # Define Neural Network for the control input input_size = x_size + 1 # size of the spatial dimensions PLUS one time dimensions - nn_initial = Chain(Dense(input_size,v_size)) # The actual neural network + nn_initial = Chain(Dense(input_size, v_size)) # The actual neural network p_nn, model = Flux.destructure(nn_initial) - nn(x,p) = model(p)(x) + nn(x, p) = model(p)(x) # Define the right hand side of the SDE const_mat = zeros(Float64, (x_size, v_size)) - for i = 1:max(x_size,v_size) - const_mat[i,i] = 1 + for i in 1:max(x_size, v_size) + const_mat[i, i] = 1 end - - function f!(du,u,p,t) - MM = nn([u;t],p) - du .= u + const_mat*MM + function f!(du, u, p, t) + MM = nn([u; t], p) + du .= u + const_mat * MM end - function g!(du,u,p,t) - du .= false*u .+ sqrt(2*0.001) + function g!(du, u, p, t) + du .= false * u .+ sqrt(2 * 0.001) end # Define SDE problem - u0 = vec(rand(Float64, (x_size,1))) + u0 = vec(rand(Float64, (x_size, 1))) tspan = (0.0, 1.0) ts = collect(0:0.1:1) prob = SDEProblem{true}(f!, g!, u0, tspan, p_nn) - W = WienerProcess(0.0,0.0,0.0) - probscalar = SDEProblem{true}(f!, g!, u0, tspan, p_nn, noise=W) + W = WienerProcess(0.0, 0.0, 0.0) + probscalar = SDEProblem{true}(f!, g!, u0, tspan, p_nn, noise = W) # Defining the loss function function loss(pars, prob, alg) function prob_func(prob, i, repeat) # Prepare new initial state and remake the problem - u0tmp = vec(rand(Float64,(x_size,1))) + u0tmp = vec(rand(Float64, (x_size, 1))) remake(prob, p = pars, u0 = u0tmp) end ensembleprob = EnsembleProblem(prob, prob_func = prob_func) - _sol = solve(ensembleprob, alg, EnsembleThreads(), sensealg = BacksolveAdjoint(), saveat = ts, trajectories = 10, - abstol=1e-1, reltol=1e-1) + _sol = solve(ensembleprob, alg, EnsembleThreads(), sensealg = BacksolveAdjoint(), + saveat = ts, trajectories = 10, + abstol = 1e-1, reltol = 1e-1) A = convert(Array, _sol) sum(abs2, A .- 1), mean(A) end # Actually training/fitting the model losses = [] - callback(θ, l, pred) = begin - push!(losses, l) - if length(losses)%1 == 0 - println("Current loss after $(length(losses)) iterations: $(losses[end])") + function callback(θ, l, pred) + begin + push!(losses, l) + if length(losses) % 1 == 0 + println("Current loss after $(length(losses)) iterations: $(losses[end])") + end + false end - false end - optf = Optimization.OptimizationFunction((p,_) -> loss(p,probscalar, LambaEM()), Optimization.AutoZygote()) + optf = Optimization.OptimizationFunction((p, _) -> loss(p, probscalar, LambaEM()), + Optimization.AutoZygote()) optprob = Optimization.OptimizationProblem(optf, p_nn) res1 = Optimization.solve(optprob, ADAM(0.1), callback = callback, maxiters = 5) - optf = Optimization.OptimizationFunction((p,_) -> loss(p,probscalar, SOSRI()), Optimization.AutoZygote()) + optf = Optimization.OptimizationFunction((p, _) -> loss(p, probscalar, SOSRI()), + Optimization.AutoZygote()) optprob = Optimization.OptimizationProblem(optf, p_nn) res2 = Optimization.solve(optprob, ADAM(0.1), callback = callback, maxiters = 5) - optf = Optimization.OptimizationFunction((p,_) -> loss(p,prob, LambaEM()), Optimization.AutoZygote()) + optf = Optimization.OptimizationFunction((p, _) -> loss(p, prob, LambaEM()), + Optimization.AutoZygote()) optprob = Optimization.OptimizationProblem(optf, p_nn) res1 = Optimization.solve(optprob, ADAM(0.1), callback = callback, maxiters = 5) end diff --git a/test/sde_nondiag_stratonovich.jl b/test/sde_nondiag_stratonovich.jl index ba8219b7f..119ee5392 100644 --- a/test/sde_nondiag_stratonovich.jl +++ b/test/sde_nondiag_stratonovich.jl @@ -15,574 +15,667 @@ trange = (tstart, tend) t = tstart:dt:tend tarray = collect(t) -function g(u,p,t) - sum(u.^2.0/2.0) +function g(u, p, t) + sum(u .^ 2.0 / 2.0) end -function dg!(out,u,p,t,i) - (out.=u) +function dg!(out, u, p, t, i) + (out .= u) end # non-diagonal noise @testset "Non-diagonal noise tests" begin Random.seed!(seed) - u₀ = [0.75,0.5] - p = [-1.5,0.05,0.2, 0.01] + u₀ = [0.75, 0.5] + p = [-1.5, 0.05, 0.2, 0.01] - dtnd = tend/1e3 + dtnd = tend / 1e3 # Example from Roessler, SIAM J. NUMER. ANAL, 48, 922–952 with d = 2; m = 2 - function f_nondiag!(du,u,p,t) - du[1] = p[1]*u[1] + p[2]*u[2] - du[2] = p[2]*u[1] + p[1]*u[2] - nothing + function f_nondiag!(du, u, p, t) + du[1] = p[1] * u[1] + p[2] * u[2] + du[2] = p[2] * u[1] + p[1] * u[2] + nothing end - function g_nondiag!(du,u,p,t) - du[1,1] = p[3]*u[1] + p[4]*u[2] - du[1,2] = p[3]*u[1] + p[4]*u[2] - du[2,1] = p[4]*u[1] + p[3]*u[2] - du[2,2] = p[4]*u[1] + p[3]*u[2] - nothing + function g_nondiag!(du, u, p, t) + du[1, 1] = p[3] * u[1] + p[4] * u[2] + du[1, 2] = p[3] * u[1] + p[4] * u[2] + du[2, 1] = p[4] * u[1] + p[3] * u[2] + du[2, 2] = p[4] * u[1] + p[3] * u[2] + nothing end - function f_nondiag(u,p,t) - dx = p[1]*u[1] + p[2]*u[2] - dy = p[2]*u[1] + p[1]*u[2] - [dx,dy] + function f_nondiag(u, p, t) + dx = p[1] * u[1] + p[2] * u[2] + dy = p[2] * u[1] + p[1] * u[2] + [dx, dy] end - function g_nondiag(u,p,t) - du11 = p[3]*u[1] + p[4]*u[2] - du12 = p[3]*u[1] + p[4]*u[2] - du21 = p[4]*u[1] + p[3]*u[2] - du22 = p[4]*u[1] + p[3]*u[2] + function g_nondiag(u, p, t) + du11 = p[3] * u[1] + p[4] * u[2] + du12 = p[3] * u[1] + p[4] * u[2] + du21 = p[4] * u[1] + p[3] * u[2] + du22 = p[4] * u[1] + p[3] * u[2] - [du11 du12 - du21 du22] + [du11 du12 + du21 du22] end - - function f_nondiag_analytic(u0,p,t,W) - A = [[p[1], p[2]] [p[2], p[1]]] - B = [[p[3], p[4]] [p[4], p[3]]] - tmp = A*t + B*W[1] + B*W[2] - exp(tmp)*u0 + function f_nondiag_analytic(u0, p, t, W) + A = [[p[1], p[2]] [p[2], p[1]]] + B = [[p[3], p[4]] [p[4], p[3]]] + tmp = A * t + B * W[1] + B * W[2] + exp(tmp) * u0 end - noise_matrix = similar(p,2,2) + noise_matrix = similar(p, 2, 2) noise_matrix .= false Random.seed!(seed) - prob = SDEProblem(f_nondiag!,g_nondiag!,u₀,trange,p,noise_rate_prototype=noise_matrix) - sol = solve(prob, EulerHeun(), dt=dtnd, save_noise=true) + prob = SDEProblem(f_nondiag!, g_nondiag!, u₀, trange, p, + noise_rate_prototype = noise_matrix) + sol = solve(prob, EulerHeun(), dt = dtnd, save_noise = true) - noise_matrix = similar(p,2,2) + noise_matrix = similar(p, 2, 2) noise_matrix .= false Random.seed!(seed) - proboop = SDEProblem(f_nondiag,g_nondiag,u₀,trange,p,noise_rate_prototype=noise_matrix) - soloop = solve(proboop,EulerHeun(), dt=dtnd, save_noise=true) - + proboop = SDEProblem(f_nondiag, g_nondiag, u₀, trange, p, + noise_rate_prototype = noise_matrix) + soloop = solve(proboop, EulerHeun(), dt = dtnd, save_noise = true) - - res_sde_u0, res_sde_p = adjoint_sensitivities(soloop,EulerHeun(),t=Array(t),dg_discrete=dg!, - dt=dtnd,adaptive=false,sensealg=BacksolveAdjoint()) + res_sde_u0, res_sde_p = adjoint_sensitivities(soloop, EulerHeun(), t = Array(t), + dg_discrete = dg!, + dt = dtnd, adaptive = false, + sensealg = BacksolveAdjoint()) @info res_sde_p - res_sde_u0a, res_sde_pa = adjoint_sensitivities(soloop,EulerHeun(),t=Array(t),dg_discrete=dg!, - dt=dtnd,adaptive=false,sensealg=BacksolveAdjoint(autojacvec=ReverseDiffVJP())) + res_sde_u0a, res_sde_pa = adjoint_sensitivities(soloop, EulerHeun(), t = Array(t), + dg_discrete = dg!, + dt = dtnd, adaptive = false, + sensealg = BacksolveAdjoint(autojacvec = ReverseDiffVJP())) - @test isapprox(res_sde_u0a, res_sde_u0, rtol=1e-6) - @test isapprox(res_sde_pa, res_sde_p, rtol=1e-6) + @test isapprox(res_sde_u0a, res_sde_u0, rtol = 1e-6) + @test isapprox(res_sde_pa, res_sde_p, rtol = 1e-6) @info res_sde_pa - res_sde_u0a, res_sde_pa = adjoint_sensitivities(soloop,EulerHeun(),t=Array(t),dg_discrete=dg!, - dt=dtnd,adaptive=false,sensealg=BacksolveAdjoint(autojacvec=false)) + res_sde_u0a, res_sde_pa = adjoint_sensitivities(soloop, EulerHeun(), t = Array(t), + dg_discrete = dg!, + dt = dtnd, adaptive = false, + sensealg = BacksolveAdjoint(autojacvec = false)) - @test isapprox(res_sde_u0a, res_sde_u0, rtol=1e-6) - @test isapprox(res_sde_pa, res_sde_p, rtol=1e-6) + @test isapprox(res_sde_u0a, res_sde_u0, rtol = 1e-6) + @test isapprox(res_sde_pa, res_sde_p, rtol = 1e-6) @info res_sde_pa - res_sde_u0a, res_sde_pa = adjoint_sensitivities(soloop,EulerHeun(),t=Array(t),dg_discrete=dg!, - dt=dtnd,adaptive=false,sensealg=InterpolatingAdjoint(autojacvec=ZygoteVJP())) + res_sde_u0a, res_sde_pa = adjoint_sensitivities(soloop, EulerHeun(), t = Array(t), + dg_discrete = dg!, + dt = dtnd, adaptive = false, + sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP())) - @test isapprox(res_sde_u0a, res_sde_u0, rtol=1e-5) - @test isapprox(res_sde_pa, res_sde_p, rtol=1e-4) + @test isapprox(res_sde_u0a, res_sde_u0, rtol = 1e-5) + @test isapprox(res_sde_pa, res_sde_p, rtol = 1e-4) @info res_sde_pa - res_sde_u0a, res_sde_pa = adjoint_sensitivities(soloop,EulerHeun(),t=Array(t),dg_discrete=dg!, - dt=dtnd,adaptive=false,sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP())) + res_sde_u0a, res_sde_pa = adjoint_sensitivities(soloop, EulerHeun(), t = Array(t), + dg_discrete = dg!, + dt = dtnd, adaptive = false, + sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP())) - @test isapprox(res_sde_u0a, res_sde_u0, rtol=1e-5) - @test isapprox(res_sde_pa, res_sde_p, rtol=1e-4) + @test isapprox(res_sde_u0a, res_sde_u0, rtol = 1e-5) + @test isapprox(res_sde_pa, res_sde_p, rtol = 1e-4) @info res_sde_pa - res_sde_u0a, res_sde_pa = adjoint_sensitivities(soloop,EulerHeun(),t=Array(t),dg_discrete=dg!, - dt=dtnd,adaptive=false,sensealg=InterpolatingAdjoint(autojacvec=false)) + res_sde_u0a, res_sde_pa = adjoint_sensitivities(soloop, EulerHeun(), t = Array(t), + dg_discrete = dg!, + dt = dtnd, adaptive = false, + sensealg = InterpolatingAdjoint(autojacvec = false)) - @test isapprox(res_sde_u0a, res_sde_u0, rtol=1e-5) - @test isapprox(res_sde_pa, res_sde_p, rtol=1e-4) + @test isapprox(res_sde_u0a, res_sde_u0, rtol = 1e-5) + @test isapprox(res_sde_pa, res_sde_p, rtol = 1e-4) @info res_sde_pa - @test_broken res_sde_u0a, res_sde_pa = adjoint_sensitivities(sol,EulerHeun(),t=Array(t),dg_discrete=dg!, - dt=dtnd,adaptive=false,sensealg=BacksolveAdjoint()) + @test_broken res_sde_u0a, res_sde_pa = adjoint_sensitivities(sol, EulerHeun(), + t = Array(t), + dg_discrete = dg!, + dt = dtnd, + adaptive = false, + sensealg = BacksolveAdjoint()) - @test isapprox(res_sde_u0a, res_sde_u0, rtol=1e-6) - @test isapprox(res_sde_pa, res_sde_p, rtol=1e-5) + @test isapprox(res_sde_u0a, res_sde_u0, rtol = 1e-6) + @test isapprox(res_sde_pa, res_sde_p, rtol = 1e-5) @info res_sde_pa - res_sde_u0a, res_sde_pa = adjoint_sensitivities(sol,EulerHeun(),t=Array(t),dg_discrete=dg!, - dt=dtnd,adaptive=false,sensealg=BacksolveAdjoint(autojacvec=ZygoteVJP())) + res_sde_u0a, res_sde_pa = adjoint_sensitivities(sol, EulerHeun(), t = Array(t), + dg_discrete = dg!, + dt = dtnd, adaptive = false, + sensealg = BacksolveAdjoint(autojacvec = ZygoteVJP())) - @test isapprox(res_sde_u0a, res_sde_u0, rtol=1e-6) - @test isapprox(res_sde_pa, res_sde_p, rtol=1e-6) + @test isapprox(res_sde_u0a, res_sde_u0, rtol = 1e-6) + @test isapprox(res_sde_pa, res_sde_p, rtol = 1e-6) @info res_sde_pa - res_sde_u0a, res_sde_pa = adjoint_sensitivities(sol,EulerHeun(),t=Array(t),dg_discrete=dg!, - dt=dtnd,adaptive=false,sensealg=BacksolveAdjoint(autojacvec=false)) + res_sde_u0a, res_sde_pa = adjoint_sensitivities(sol, EulerHeun(), t = Array(t), + dg_discrete = dg!, + dt = dtnd, adaptive = false, + sensealg = BacksolveAdjoint(autojacvec = false)) - @test isapprox(res_sde_u0a, res_sde_u0, rtol=1e-6) - @test isapprox(res_sde_pa, res_sde_p, rtol=1e-6) + @test isapprox(res_sde_u0a, res_sde_u0, rtol = 1e-6) + @test isapprox(res_sde_pa, res_sde_p, rtol = 1e-6) @info res_sde_pa - @test_broken res_sde_u0a, res_sde_pa = adjoint_sensitivities(sol,EulerHeun(),t=Array(t),dg_discrete=dg!, - dt=dtnd,adaptive=false,sensealg=InterpolatingAdjoint()) + @test_broken res_sde_u0a, res_sde_pa = adjoint_sensitivities(sol, EulerHeun(), + t = Array(t), + dg_discrete = dg!, + dt = dtnd, + adaptive = false, + sensealg = InterpolatingAdjoint()) - @test isapprox(res_sde_u0a, res_sde_u0, rtol=1e-5) - @test isapprox(res_sde_pa, res_sde_p, rtol=1e-4) + @test isapprox(res_sde_u0a, res_sde_u0, rtol = 1e-5) + @test isapprox(res_sde_pa, res_sde_p, rtol = 1e-4) @info res_sde_pa - res_sde_u0a, res_sde_pa = adjoint_sensitivities(sol,EulerHeun(),t=Array(t),dg_discrete=dg!, - dt=dtnd,adaptive=false,sensealg=InterpolatingAdjoint(autojacvec=ZygoteVJP())) + res_sde_u0a, res_sde_pa = adjoint_sensitivities(sol, EulerHeun(), t = Array(t), + dg_discrete = dg!, + dt = dtnd, adaptive = false, + sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP())) - @test isapprox(res_sde_u0a, res_sde_u0, rtol=1e-5) - @test isapprox(res_sde_pa, res_sde_p, rtol=1e-4) + @test isapprox(res_sde_u0a, res_sde_u0, rtol = 1e-5) + @test isapprox(res_sde_pa, res_sde_p, rtol = 1e-4) @info res_sde_pa - res_sde_u0a, res_sde_pa = adjoint_sensitivities(sol,EulerHeun(),t=Array(t),dg_discrete=dg!, - dt=dtnd,adaptive=false,sensealg=InterpolatingAdjoint(autojacvec=false)) + res_sde_u0a, res_sde_pa = adjoint_sensitivities(sol, EulerHeun(), t = Array(t), + dg_discrete = dg!, + dt = dtnd, adaptive = false, + sensealg = InterpolatingAdjoint(autojacvec = false)) - @test isapprox(res_sde_u0a, res_sde_u0, rtol=1e-5) - @test isapprox(res_sde_pa, res_sde_p, rtol=1e-4) + @test isapprox(res_sde_u0a, res_sde_u0, rtol = 1e-5) + @test isapprox(res_sde_pa, res_sde_p, rtol = 1e-4) @info res_sde_pa function compute_grads_nd(sol) - xdis = sol(tarray) + xdis = sol(tarray) - mat1 = Matrix{Int}(I, 2, 2) - mat2 = ones(2,2)-mat1 + mat1 = Matrix{Int}(I, 2, 2) + mat2 = ones(2, 2) - mat1 - tmp1 = similar(p) - tmp1 *= false + tmp1 = similar(p) + tmp1 *= false - tmp2 = similar(xdis.u[1]) - tmp2 *= false + tmp2 = similar(xdis.u[1]) + tmp2 *= false - for (i, u) in enumerate(xdis) - tmp1[1]+=xdis.t[i]*u'*mat1*u - tmp1[2]+=xdis.t[i]*u'*mat2*u - tmp1[3]+=sum(sol.W(xdis.t[i])[1])*u'*mat1*u - tmp1[4]+=sum(sol.W(xdis.t[i])[1])*u'*mat2*u + for (i, u) in enumerate(xdis) + tmp1[1] += xdis.t[i] * u' * mat1 * u + tmp1[2] += xdis.t[i] * u' * mat2 * u + tmp1[3] += sum(sol.W(xdis.t[i])[1]) * u' * mat1 * u + tmp1[4] += sum(sol.W(xdis.t[i])[1]) * u' * mat2 * u - tmp2 += u.^2 - end + tmp2 += u .^ 2 + end - return tmp2 ./ xdis.u[1], tmp1 + return tmp2 ./ xdis.u[1], tmp1 end res1, res2 = compute_grads_nd(soloop) - @test isapprox(res1, res_sde_u0, rtol=1e-4) - @test isapprox(res2, res_sde_p', rtol=1e-4) + @test isapprox(res1, res_sde_u0, rtol = 1e-4) + @test isapprox(res2, res_sde_p', rtol = 1e-4) end - - @testset "diagonal but mixing noise tests" begin - Random.seed!(seed) - u₀ = [0.75,0.5] - p = [-1.5,0.05,0.2, 0.01] - dtmix = tend/1e3 - - # Example from Roessler, SIAM J. NUMER. ANAL, 48, 922–952 with d = 2; m = 2 - function f_mixing!(du,u,p,t) - du[1] = p[1]*u[1] + p[2]*u[2] - du[2] = p[2]*u[1] + p[1]*u[2] - nothing - end - - function g_mixing!(du,u,p,t) - du[1] = p[3]*u[1] + p[4]*u[2] - du[2] = p[3]*u[1] + p[4]*u[2] - nothing - end - - function f_mixing(u,p,t) - dx = p[1]*u[1] + p[2]*u[2] - dy = p[2]*u[1] + p[1]*u[2] - [dx,dy] - end - - function g_mixing(u,p,t) - dx = p[3]*u[1] + p[4]*u[2] - dy = p[3]*u[1] + p[4]*u[2] - [dx,dy] - end + Random.seed!(seed) + u₀ = [0.75, 0.5] + p = [-1.5, 0.05, 0.2, 0.01] + dtmix = tend / 1e3 - Random.seed!(seed) - prob = SDEProblem(f_mixing!,g_mixing!,u₀,trange,p) + # Example from Roessler, SIAM J. NUMER. ANAL, 48, 922–952 with d = 2; m = 2 + function f_mixing!(du, u, p, t) + du[1] = p[1] * u[1] + p[2] * u[2] + du[2] = p[2] * u[1] + p[1] * u[2] + nothing + end - soltsave = collect(trange[1]:dtmix:trange[2]) - sol = solve(prob, EulerHeun(), dt=dtmix, save_noise=true, saveat=soltsave) + function g_mixing!(du, u, p, t) + du[1] = p[3] * u[1] + p[4] * u[2] + du[2] = p[3] * u[1] + p[4] * u[2] + nothing + end - Random.seed!(seed) - proboop = SDEProblem(f_mixing,g_mixing,u₀,trange,p) - soloop = solve(proboop,EulerHeun(), dt=dtmix, save_noise=true, saveat=soltsave) + function f_mixing(u, p, t) + dx = p[1] * u[1] + p[2] * u[2] + dy = p[2] * u[1] + p[1] * u[2] + [dx, dy] + end + function g_mixing(u, p, t) + dx = p[3] * u[1] + p[4] * u[2] + dy = p[3] * u[1] + p[4] * u[2] + [dx, dy] + end - #oop + Random.seed!(seed) + prob = SDEProblem(f_mixing!, g_mixing!, u₀, trange, p) - res_sde_u0, res_sde_p = adjoint_sensitivities(soloop,EulerHeun(),t=tarray,dg_discrete=dg!, - dt=dtmix,adaptive=false,sensealg=BacksolveAdjoint(noisemixing=true)) + soltsave = collect(trange[1]:dtmix:trange[2]) + sol = solve(prob, EulerHeun(), dt = dtmix, save_noise = true, saveat = soltsave) - @info res_sde_p + Random.seed!(seed) + proboop = SDEProblem(f_mixing, g_mixing, u₀, trange, p) + soloop = solve(proboop, EulerHeun(), dt = dtmix, save_noise = true, saveat = soltsave) - res_sde_u0a, res_sde_pa = adjoint_sensitivities(soloop,EulerHeun(),t=Array(t),dg_discrete=dg!, - dt=dtmix,adaptive=false,sensealg=BacksolveAdjoint(autojacvec=ZygoteVJP(), noisemixing=true)) + #oop - @test isapprox(res_sde_u0a, res_sde_u0, rtol=1e-6) - @test isapprox(res_sde_pa, res_sde_p, rtol=1e-6) + res_sde_u0, res_sde_p = adjoint_sensitivities(soloop, EulerHeun(), t = tarray, + dg_discrete = dg!, + dt = dtmix, adaptive = false, + sensealg = BacksolveAdjoint(noisemixing = true)) - @info res_sde_pa + @info res_sde_p - res_sde_u0a, res_sde_pa = adjoint_sensitivities(soloop,EulerHeun(),t=Array(t),dg_discrete=dg!, - dt=dtmix,adaptive=false,sensealg=BacksolveAdjoint(autojacvec=false,noisemixing=true)) + res_sde_u0a, res_sde_pa = adjoint_sensitivities(soloop, EulerHeun(), t = Array(t), + dg_discrete = dg!, + dt = dtmix, adaptive = false, + sensealg = BacksolveAdjoint(autojacvec = ZygoteVJP(), + noisemixing = true)) - @test isapprox(res_sde_u0a, res_sde_u0, rtol=1e-6) - @test isapprox(res_sde_pa, res_sde_p, rtol=1e-6) + @test isapprox(res_sde_u0a, res_sde_u0, rtol = 1e-6) + @test isapprox(res_sde_pa, res_sde_p, rtol = 1e-6) - @info res_sde_pa + @info res_sde_pa - res_sde_u0a, res_sde_pa = adjoint_sensitivities(soloop,EulerHeun(),t=Array(t),dg_discrete=dg!, - dt=dtmix,adaptive=false,sensealg=BacksolveAdjoint(autojacvec=ReverseDiffVJP(), noisemixing=true)) + res_sde_u0a, res_sde_pa = adjoint_sensitivities(soloop, EulerHeun(), t = Array(t), + dg_discrete = dg!, + dt = dtmix, adaptive = false, + sensealg = BacksolveAdjoint(autojacvec = false, + noisemixing = true)) + @test isapprox(res_sde_u0a, res_sde_u0, rtol = 1e-6) + @test isapprox(res_sde_pa, res_sde_p, rtol = 1e-6) - @test isapprox(res_sde_u0a, res_sde_u0, rtol=1e-6) - @test isapprox(res_sde_pa, res_sde_p, rtol=1e-6) + @info res_sde_pa - @info res_sde_pa + res_sde_u0a, res_sde_pa = adjoint_sensitivities(soloop, EulerHeun(), t = Array(t), + dg_discrete = dg!, + dt = dtmix, adaptive = false, + sensealg = BacksolveAdjoint(autojacvec = ReverseDiffVJP(), + noisemixing = true)) - res_sde_u0a, res_sde_pa = adjoint_sensitivities(soloop,EulerHeun(),t=tarray,dg_discrete=dg!, - dt=dtmix,adaptive=false,sensealg=InterpolatingAdjoint(noisemixing=true)) + @test isapprox(res_sde_u0a, res_sde_u0, rtol = 1e-6) + @test isapprox(res_sde_pa, res_sde_p, rtol = 1e-6) - @test isapprox(res_sde_u0a, res_sde_u0, rtol=1e-5) - @test isapprox(res_sde_pa, res_sde_p, rtol=1e-5) + @info res_sde_pa - @info res_sde_pa + res_sde_u0a, res_sde_pa = adjoint_sensitivities(soloop, EulerHeun(), t = tarray, + dg_discrete = dg!, + dt = dtmix, adaptive = false, + sensealg = InterpolatingAdjoint(noisemixing = true)) - res_sde_u0a, res_sde_pa = adjoint_sensitivities(soloop,EulerHeun(),t=tarray,dg_discrete=dg!, - dt=dtmix,adaptive=false,sensealg=InterpolatingAdjoint(noisemixing=true, autojacvec=ZygoteVJP())) + @test isapprox(res_sde_u0a, res_sde_u0, rtol = 1e-5) + @test isapprox(res_sde_pa, res_sde_p, rtol = 1e-5) - @test isapprox(res_sde_u0a, res_sde_u0, rtol=1e-5) - @test isapprox(res_sde_pa, res_sde_p, rtol=1e-5) + @info res_sde_pa - @info res_sde_pa + res_sde_u0a, res_sde_pa = adjoint_sensitivities(soloop, EulerHeun(), t = tarray, + dg_discrete = dg!, + dt = dtmix, adaptive = false, + sensealg = InterpolatingAdjoint(noisemixing = true, + autojacvec = ZygoteVJP())) - res_sde_u0a, res_sde_pa = adjoint_sensitivities(soloop,EulerHeun(),t=tarray,dg_discrete=dg!, - dt=dtmix,adaptive=false,sensealg=InterpolatingAdjoint(autojacvec=false, noisemixing=true)) + @test isapprox(res_sde_u0a, res_sde_u0, rtol = 1e-5) + @test isapprox(res_sde_pa, res_sde_p, rtol = 1e-5) - @test isapprox(res_sde_u0a, res_sde_u0, rtol=1e-5) - @test isapprox(res_sde_pa, res_sde_p, rtol=1e-5) + @info res_sde_pa - @info res_sde_pa + res_sde_u0a, res_sde_pa = adjoint_sensitivities(soloop, EulerHeun(), t = tarray, + dg_discrete = dg!, + dt = dtmix, adaptive = false, + sensealg = InterpolatingAdjoint(autojacvec = false, + noisemixing = true)) - res_sde_u0a, res_sde_pa = adjoint_sensitivities(soloop,EulerHeun(),t=Array(t),dg_discrete=dg!, - dt=dtmix,adaptive=false,sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(),noisemixing=true)) + @test isapprox(res_sde_u0a, res_sde_u0, rtol = 1e-5) + @test isapprox(res_sde_pa, res_sde_p, rtol = 1e-5) - res_sde_u0a, res_sde_pa = adjoint_sensitivities(soloop,EulerHeun(),t=tarray,dg_discrete=dg!, - dt=dtmix,adaptive=false,sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(),noisemixing=true)) + @info res_sde_pa - @test isapprox(res_sde_u0a, res_sde_u0, rtol=1e-5) - @test isapprox(res_sde_pa, res_sde_p, rtol=1e-5) + res_sde_u0a, res_sde_pa = adjoint_sensitivities(soloop, EulerHeun(), t = Array(t), + dg_discrete = dg!, + dt = dtmix, adaptive = false, + sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP(), + noisemixing = true)) - @info res_sde_pa + res_sde_u0a, res_sde_pa = adjoint_sensitivities(soloop, EulerHeun(), t = tarray, + dg_discrete = dg!, + dt = dtmix, adaptive = false, + sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP(), + noisemixing = true)) - @test_broken res_sde_u0, res_sde_p = adjoint_sensitivities(sol,EulerHeun(),t=Array(t),dg_discrete=dg!, - dt=dtmix,adaptive=false,sensealg=BacksolveAdjoint(noisemixing=true)) + @test isapprox(res_sde_u0a, res_sde_u0, rtol = 1e-5) + @test isapprox(res_sde_pa, res_sde_p, rtol = 1e-5) - @test isapprox(res_sde_u0a, res_sde_u0, rtol=1e-5) - @test isapprox(res_sde_pa, res_sde_p, rtol=1e-5) + @info res_sde_pa - @info res_sde_p + @test_broken res_sde_u0, res_sde_p = adjoint_sensitivities(sol, EulerHeun(), + t = Array(t), + dg_discrete = dg!, + dt = dtmix, adaptive = false, + sensealg = BacksolveAdjoint(noisemixing = true)) - res_sde_u0a, res_sde_pa = adjoint_sensitivities(sol,EulerHeun(),t=Array(t),dg_discrete=dg!, - dt=dtmix,adaptive=false,sensealg=BacksolveAdjoint(autojacvec=ZygoteVJP(), noisemixing=true)) + @test isapprox(res_sde_u0a, res_sde_u0, rtol = 1e-5) + @test isapprox(res_sde_pa, res_sde_p, rtol = 1e-5) + @info res_sde_p - @test isapprox(res_sde_u0a, res_sde_u0, rtol=1e-6) - @test isapprox(res_sde_pa, res_sde_p, rtol=1e-6) + res_sde_u0a, res_sde_pa = adjoint_sensitivities(sol, EulerHeun(), t = Array(t), + dg_discrete = dg!, + dt = dtmix, adaptive = false, + sensealg = BacksolveAdjoint(autojacvec = ZygoteVJP(), + noisemixing = true)) - @info res_sde_pa + @test isapprox(res_sde_u0a, res_sde_u0, rtol = 1e-6) + @test isapprox(res_sde_pa, res_sde_p, rtol = 1e-6) - res_sde_u0a, res_sde_pa = adjoint_sensitivities(sol,EulerHeun(),t=Array(t),dg_discrete=dg!, - dt=dtmix,adaptive=false,sensealg=BacksolveAdjoint(autojacvec=false,noisemixing=true)) + @info res_sde_pa - @test isapprox(res_sde_u0a, res_sde_u0, rtol=1e-6) - @test isapprox(res_sde_pa, res_sde_p, rtol=1e-6) + res_sde_u0a, res_sde_pa = adjoint_sensitivities(sol, EulerHeun(), t = Array(t), + dg_discrete = dg!, + dt = dtmix, adaptive = false, + sensealg = BacksolveAdjoint(autojacvec = false, + noisemixing = true)) - @info res_sde_pa + @test isapprox(res_sde_u0a, res_sde_u0, rtol = 1e-6) + @test isapprox(res_sde_pa, res_sde_p, rtol = 1e-6) - res_sde_u0a, res_sde_pa = adjoint_sensitivities(sol,EulerHeun(),t=Array(t),dg_discrete=dg!, - dt=dtmix,adaptive=false,sensealg=BacksolveAdjoint(autojacvec=ReverseDiffVJP(), noisemixing=true)) + @info res_sde_pa + res_sde_u0a, res_sde_pa = adjoint_sensitivities(sol, EulerHeun(), t = Array(t), + dg_discrete = dg!, + dt = dtmix, adaptive = false, + sensealg = BacksolveAdjoint(autojacvec = ReverseDiffVJP(), + noisemixing = true)) - @test isapprox(res_sde_u0a, res_sde_u0, rtol=1e-6) - @test isapprox(res_sde_pa, res_sde_p, rtol=1e-6) + @test isapprox(res_sde_u0a, res_sde_u0, rtol = 1e-6) + @test isapprox(res_sde_pa, res_sde_p, rtol = 1e-6) - @info res_sde_pa + @info res_sde_pa - @test_broken res_sde_u0a, res_sde_pa = adjoint_sensitivities(sol,EulerHeun(),t=Array(t),dg_discrete=dg!, - dt=dtmix,adaptive=false,sensealg=InterpolatingAdjoint(noisemixing=true)) + @test_broken res_sde_u0a, res_sde_pa = adjoint_sensitivities(sol, EulerHeun(), + t = Array(t), + dg_discrete = dg!, + dt = dtmix, + adaptive = false, + sensealg = InterpolatingAdjoint(noisemixing = true)) - @test isapprox(res_sde_u0a, res_sde_u0, rtol=1e-5) - @test isapprox(res_sde_pa, res_sde_p, rtol=1e-5) # would pass with 1e-4 but last noise value is off + @test isapprox(res_sde_u0a, res_sde_u0, rtol = 1e-5) + @test isapprox(res_sde_pa, res_sde_p, rtol = 1e-5) # would pass with 1e-4 but last noise value is off - @info res_sde_pa + @info res_sde_pa - res_sde_u0a, res_sde_pa = adjoint_sensitivities(sol,EulerHeun(),t=Array(t),dg_discrete=dg!, - dt=dtmix,adaptive=false,sensealg=InterpolatingAdjoint(noisemixing=true, autojacvec=ZygoteVJP())) + res_sde_u0a, res_sde_pa = adjoint_sensitivities(sol, EulerHeun(), t = Array(t), + dg_discrete = dg!, + dt = dtmix, adaptive = false, + sensealg = InterpolatingAdjoint(noisemixing = true, + autojacvec = ZygoteVJP())) - @test isapprox(res_sde_u0a, res_sde_u0, rtol=1e-5) - @test isapprox(res_sde_pa, res_sde_p, rtol=1e-5) + @test isapprox(res_sde_u0a, res_sde_u0, rtol = 1e-5) + @test isapprox(res_sde_pa, res_sde_p, rtol = 1e-5) - @info res_sde_pa + @info res_sde_pa - res_sde_u0a, res_sde_pa = adjoint_sensitivities(sol,EulerHeun(),t=tarray,dg_discrete=dg!, - dt=dtmix,adaptive=false,sensealg=InterpolatingAdjoint(autojacvec=false, noisemixing=true)) + res_sde_u0a, res_sde_pa = adjoint_sensitivities(sol, EulerHeun(), t = tarray, + dg_discrete = dg!, + dt = dtmix, adaptive = false, + sensealg = InterpolatingAdjoint(autojacvec = false, + noisemixing = true)) - @test isapprox(res_sde_u0a, res_sde_u0, rtol=1e-5) - @test isapprox(res_sde_pa, res_sde_p, rtol=1e-5) + @test isapprox(res_sde_u0a, res_sde_u0, rtol = 1e-5) + @test isapprox(res_sde_pa, res_sde_p, rtol = 1e-5) - @info res_sde_pa + @info res_sde_pa - res_sde_u0a, res_sde_pa = adjoint_sensitivities(sol,EulerHeun(),t=Array(t),dg_discrete=dg!, - dt=dtmix,adaptive=false,sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(),noisemixing=true)) + res_sde_u0a, res_sde_pa = adjoint_sensitivities(sol, EulerHeun(), t = Array(t), + dg_discrete = dg!, + dt = dtmix, adaptive = false, + sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP(), + noisemixing = true)) - @test isapprox(res_sde_u0a, res_sde_u0, rtol=1e-5) - @test isapprox(res_sde_pa, res_sde_p, rtol=1e-5) + @test isapprox(res_sde_u0a, res_sde_u0, rtol = 1e-5) + @test isapprox(res_sde_pa, res_sde_p, rtol = 1e-5) - @info res_sde_pa + @info res_sde_pa - function GSDE(p) - Random.seed!(seed) - tmp_prob = remake(prob,u0=eltype(p).(prob.u0),p=p, - tspan=eltype(p).(prob.tspan)) - _sol = solve(tmp_prob,EulerHeun(),dt=dtmix,adaptive=false,saveat=Array(t)) - A = convert(Array,_sol) - res = g(A,p,nothing) - end + function GSDE(p) + Random.seed!(seed) + tmp_prob = remake(prob, u0 = eltype(p).(prob.u0), p = p, + tspan = eltype(p).(prob.tspan)) + _sol = solve(tmp_prob, EulerHeun(), dt = dtmix, adaptive = false, saveat = Array(t)) + A = convert(Array, _sol) + res = g(A, p, nothing) + end - res_sde_forward = ForwardDiff.gradient(GSDE,p) + res_sde_forward = ForwardDiff.gradient(GSDE, p) - @test isapprox(res_sde_p', res_sde_forward, rtol=1e-5) + @test isapprox(res_sde_p', res_sde_forward, rtol = 1e-5) - function GSDE2(u0) - Random.seed!(seed) - tmp_prob = remake(prob,u0=u0,p=eltype(p).(prob.p), - tspan=eltype(p).(prob.tspan)) - _sol = solve(tmp_prob,EulerHeun(),dt=dtmix,adaptive=false,saveat=Array(t)) - A = convert(Array,_sol) - res = g(A,p,nothing) - end + function GSDE2(u0) + Random.seed!(seed) + tmp_prob = remake(prob, u0 = u0, p = eltype(p).(prob.p), + tspan = eltype(p).(prob.tspan)) + _sol = solve(tmp_prob, EulerHeun(), dt = dtmix, adaptive = false, saveat = Array(t)) + A = convert(Array, _sol) + res = g(A, p, nothing) + end - res_sde_forward = ForwardDiff.gradient(GSDE2,u₀) + res_sde_forward = ForwardDiff.gradient(GSDE2, u₀) - @test isapprox(res_sde_forward, res_sde_u0, rtol=1e-5) + @test isapprox(res_sde_forward, res_sde_u0, rtol = 1e-5) end - @testset "mixing noise inplace/oop tests" begin - Random.seed!(seed) - u₀ = [0.75,0.5] - p = [-1.5,0.05,0.2, 0.01] - dtmix = tend/1e3 - - # Example from Roessler, SIAM J. NUMER. ANAL, 48, 922–952 with d = 2; m = 2 - function f_mixing!(du,u,p,t) - du[1] = p[1]*u[1] + p[2]*u[2] - du[2] = p[2]*u[1] + p[1]*u[2] - nothing - end - - function g_mixing!(du,u,p,t) - du[1] = p[3]*u[1] + p[4]*u[2] - du[2] = p[3]*u[1] + p[4]*u[2] - nothing - end - - function f_mixing(u,p,t) - dx = p[1]*u[1] + p[2]*u[2] - dy = p[2]*u[1] + p[1]*u[2] - [dx,dy] - end - - function g_mixing(u,p,t) - dx = p[3]*u[1] + p[4]*u[2] - dy = p[3]*u[1] + p[4]*u[2] - [dx,dy] - end - - Random.seed!(seed) - prob = SDEProblem(f_mixing!,g_mixing!,u₀,trange,p) - soltsave = collect(trange[1]:dtmix:trange[2]) - sol = solve(prob, EulerHeun(), dt=dtmix, save_noise=true, saveat=soltsave) - - Random.seed!(seed) - proboop = SDEProblem(f_mixing,g_mixing,u₀,trange,p) - soloop = solve(proboop,EulerHeun(), dt=dtmix, save_noise=true, saveat=soltsave) - - @test sol.u ≈ soloop.u atol = 1e-14 - - - # BacksolveAdjoint + Random.seed!(seed) + u₀ = [0.75, 0.5] + p = [-1.5, 0.05, 0.2, 0.01] + dtmix = tend / 1e3 - res_sde_u0, res_sde_p = adjoint_sensitivities(soloop,EulerHeun(),t=tarray,dg_discrete=dg!, - dt=dtmix,adaptive=false,sensealg=BacksolveAdjoint(noisemixing=true)) + # Example from Roessler, SIAM J. NUMER. ANAL, 48, 922–952 with d = 2; m = 2 + function f_mixing!(du, u, p, t) + du[1] = p[1] * u[1] + p[2] * u[2] + du[2] = p[2] * u[1] + p[1] * u[2] + nothing + end + function g_mixing!(du, u, p, t) + du[1] = p[3] * u[1] + p[4] * u[2] + du[2] = p[3] * u[1] + p[4] * u[2] + nothing + end - @test_broken res_sde_u02, res_sde_p2 = adjoint_sensitivities(sol,EulerHeun(),t=tarray,dg_discrete=dg!, - dt=dtmix,adaptive=false,sensealg=BacksolveAdjoint(noisemixing=true)) + function f_mixing(u, p, t) + dx = p[1] * u[1] + p[2] * u[2] + dy = p[2] * u[1] + p[1] * u[2] + [dx, dy] + end + function g_mixing(u, p, t) + dx = p[3] * u[1] + p[4] * u[2] + dy = p[3] * u[1] + p[4] * u[2] + [dx, dy] + end - @test_broken res_sde_u0 ≈ res_sde_u02 atol = 1e-14 - @test_broken res_sde_p ≈ res_sde_p2 atol = 1e-14 + Random.seed!(seed) + prob = SDEProblem(f_mixing!, g_mixing!, u₀, trange, p) + soltsave = collect(trange[1]:dtmix:trange[2]) + sol = solve(prob, EulerHeun(), dt = dtmix, save_noise = true, saveat = soltsave) - @show res_sde_u0 + Random.seed!(seed) + proboop = SDEProblem(f_mixing, g_mixing, u₀, trange, p) + soloop = solve(proboop, EulerHeun(), dt = dtmix, save_noise = true, saveat = soltsave) - adjproboop = SDEAdjointProblem(soloop,BacksolveAdjoint(autojacvec=ZygoteVJP(),noisemixing=true),tarray,dg!) - adj_soloop = solve(adjproboop,EulerHeun(); dt=dtmix, tstops=soloop.t, adaptive=false) + @test sol.u≈soloop.u atol=1e-14 + # BacksolveAdjoint - @test adj_soloop[end][length(p)+length(u₀)+1:end] == soloop.u[1] - @test adj_soloop[end][1:length(u₀)] == res_sde_u0 - @test adj_soloop[end][length(u₀)+1:end-length(u₀)] == res_sde_p' + res_sde_u0, res_sde_p = adjoint_sensitivities(soloop, EulerHeun(), t = tarray, + dg_discrete = dg!, + dt = dtmix, adaptive = false, + sensealg = BacksolveAdjoint(noisemixing = true)) - adjprob = SDEAdjointProblem(sol,BacksolveAdjoint(autojacvec=ReverseDiffVJP(),noisemixing=true,checkpointing=true),tarray,dg!) - adj_sol = solve(adjprob,EulerHeun(); dt=dtmix, adaptive=false,tstops=soloop.t) + @test_broken res_sde_u02, res_sde_p2 = adjoint_sensitivities(sol, EulerHeun(), + t = tarray, + dg_discrete = dg!, + dt = dtmix, + adaptive = false, + sensealg = BacksolveAdjoint(noisemixing = true)) - @test adj_soloop[end] ≈ adj_sol[end] rtol=1e-15 + @test_broken res_sde_u0≈res_sde_u02 atol=1e-14 + @test_broken res_sde_p≈res_sde_p2 atol=1e-14 + @show res_sde_u0 - adjprob = SDEAdjointProblem(sol,BacksolveAdjoint(autojacvec=ReverseDiffVJP(),noisemixing=true,checkpointing=false),tarray,dg!) - adj_sol = solve(adjprob,EulerHeun(); dt=dtmix, adaptive=false,tstops=soloop.t) + adjproboop = SDEAdjointProblem(soloop, + BacksolveAdjoint(autojacvec = ZygoteVJP(), + noisemixing = true), tarray, dg!) + adj_soloop = solve(adjproboop, EulerHeun(); dt = dtmix, tstops = soloop.t, + adaptive = false) - @test adj_soloop[end] ≈ adj_sol[end] rtol=1e-8 + @test adj_soloop[end][(length(p) + length(u₀) + 1):end] == soloop.u[1] + @test adj_soloop[end][1:length(u₀)] == res_sde_u0 + @test adj_soloop[end][(length(u₀) + 1):(end - length(u₀))] == res_sde_p' + adjprob = SDEAdjointProblem(sol, + BacksolveAdjoint(autojacvec = ReverseDiffVJP(), + noisemixing = true, checkpointing = true), + tarray, dg!) + adj_sol = solve(adjprob, EulerHeun(); dt = dtmix, adaptive = false, tstops = soloop.t) - # InterpolatingAdjoint + @test adj_soloop[end]≈adj_sol[end] rtol=1e-15 - res_sde_u0, res_sde_p = adjoint_sensitivities(soloop,EulerHeun(),t=tarray,dg_discrete=dg!, - dt=dtmix,adaptive=false,sensealg=InterpolatingAdjoint(noisemixing=true)) + adjprob = SDEAdjointProblem(sol, + BacksolveAdjoint(autojacvec = ReverseDiffVJP(), + noisemixing = true, checkpointing = false), + tarray, dg!) + adj_sol = solve(adjprob, EulerHeun(); dt = dtmix, adaptive = false, tstops = soloop.t) + @test adj_soloop[end]≈adj_sol[end] rtol=1e-8 - @test_broken res_sde_u02, res_sde_p2 = adjoint_sensitivities(sol,EulerHeun(),t=tarray,dg_discrete=dg!, - dt=dtmix,adaptive=false,sensealg=InterpolatingAdjoint(noisemixing=true)) + # InterpolatingAdjoint - @test_broken res_sde_u0 ≈ res_sde_u02 atol = 1e-8 - @test_broken res_sde_p ≈ res_sde_p2 atol = 5e-8 + res_sde_u0, res_sde_p = adjoint_sensitivities(soloop, EulerHeun(), t = tarray, + dg_discrete = dg!, + dt = dtmix, adaptive = false, + sensealg = InterpolatingAdjoint(noisemixing = true)) - @show res_sde_u0 + @test_broken res_sde_u02, res_sde_p2 = adjoint_sensitivities(sol, EulerHeun(), + t = tarray, + dg_discrete = dg!, + dt = dtmix, + adaptive = false, + sensealg = InterpolatingAdjoint(noisemixing = true)) - adjproboop = SDEAdjointProblem(soloop,InterpolatingAdjoint(autojacvec=ReverseDiffVJP(),noisemixing=true),tarray,dg!) - adj_soloop = solve(adjproboop,EulerHeun(); dt=dtmix, tstops=soloop.t, adaptive=false) + @test_broken res_sde_u0≈res_sde_u02 atol=1e-8 + @test_broken res_sde_p≈res_sde_p2 atol=5e-8 + @show res_sde_u0 - @test adj_soloop[end][1:length(u₀)] ≈ res_sde_u0 atol = 1e-14 - @test adj_soloop[end][length(u₀)+1:end] ≈ res_sde_p' atol = 1e-14 + adjproboop = SDEAdjointProblem(soloop, + InterpolatingAdjoint(autojacvec = ReverseDiffVJP(), + noisemixing = true), tarray, dg!) + adj_soloop = solve(adjproboop, EulerHeun(); dt = dtmix, tstops = soloop.t, + adaptive = false) - adjprob = SDEAdjointProblem(sol,InterpolatingAdjoint(autojacvec=ReverseDiffVJP(),noisemixing=true,checkpointing=true),tarray,dg!) - adj_sol = solve(adjprob,EulerHeun(); dt=dtmix, adaptive=false,tstops=soloop.t) + @test adj_soloop[end][1:length(u₀)]≈res_sde_u0 atol=1e-14 + @test adj_soloop[end][(length(u₀) + 1):end]≈res_sde_p' atol=1e-14 - @test adj_soloop[end] ≈ adj_sol[end] rtol=1e-8 + adjprob = SDEAdjointProblem(sol, + InterpolatingAdjoint(autojacvec = ReverseDiffVJP(), + noisemixing = true, + checkpointing = true), tarray, dg!) + adj_sol = solve(adjprob, EulerHeun(); dt = dtmix, adaptive = false, tstops = soloop.t) + @test adj_soloop[end]≈adj_sol[end] rtol=1e-8 - adjprob = SDEAdjointProblem(sol, InterpolatingAdjoint(autojacvec=ReverseDiffVJP(),noisemixing=true,checkpointing=false),tarray,dg!) - adj_sol = solve(adjprob,EulerHeun(); dt=dtmix, adaptive=false,tstops=soloop.t) + adjprob = SDEAdjointProblem(sol, + InterpolatingAdjoint(autojacvec = ReverseDiffVJP(), + noisemixing = true, + checkpointing = false), tarray, dg!) + adj_sol = solve(adjprob, EulerHeun(); dt = dtmix, adaptive = false, tstops = soloop.t) - @test adj_soloop[end] ≈ adj_sol[end] rtol=1e-8 + @test adj_soloop[end]≈adj_sol[end] rtol=1e-8 end @testset "mutating non-diagonal noise" begin - a!(du,u,_p,t) = (du .= -u) - a(u,_p,t) = -u + a!(du, u, _p, t) = (du .= -u) + a(u, _p, t) = -u - function b!(du,u,_p,t) - KR, KI = _p[1:2] + function b!(du, u, _p, t) + KR, KI = _p[1:2] - du[1,1] = KR - du[2,1] = KI - end + du[1, 1] = KR + du[2, 1] = KI + end - function b(u,_p,t) - KR, KI = _p[1:2] + function b(u, _p, t) + KR, KI = _p[1:2] - [ KR zero(KR) - KI zero(KR) ] - end + [KR zero(KR) + KI zero(KR)] + end - p = [1.,0.] + p = [1.0, 0.0] - prob! = SDEProblem{true}(a!,b!,[0.,0.],(0.0,0.1),p,noise_rate_prototype=eltype(p).(zeros(2,2))) - prob = SDEProblem{false}(a,b,[0.,0.],(0.0,0.1),p,noise_rate_prototype=eltype(p).(zeros(2,2))) + prob! = SDEProblem{true}(a!, b!, [0.0, 0.0], (0.0, 0.1), p, + noise_rate_prototype = eltype(p).(zeros(2, 2))) + prob = SDEProblem{false}(a, b, [0.0, 0.0], (0.0, 0.1), p, + noise_rate_prototype = eltype(p).(zeros(2, 2))) - function loss(p;SDEprob=prob,sensealg=BacksolveAdjoint()) - _prob = remake(SDEprob,p=p) - sol = solve(_prob, EulerHeun(), dt=1e-5, sensealg=sensealg) - return sum(Array(sol)) - end + function loss(p; SDEprob = prob, sensealg = BacksolveAdjoint()) + _prob = remake(SDEprob, p = p) + sol = solve(_prob, EulerHeun(), dt = 1e-5, sensealg = sensealg) + return sum(Array(sol)) + end - function compute_dp(p, SDEprob, sensealg) - Random.seed!(seed) - Zygote.gradient(p->loss(p,SDEprob=SDEprob,sensealg=sensealg), p)[1] - end + function compute_dp(p, SDEprob, sensealg) + Random.seed!(seed) + Zygote.gradient(p -> loss(p, SDEprob = SDEprob, sensealg = sensealg), p)[1] + end - # test mutating against non-mutating + # test mutating against non-mutating - # non-mutating + # non-mutating - dp1 = compute_dp(p, prob, ForwardDiffSensitivity()) - dp2 = compute_dp(p, prob, BacksolveAdjoint()) - dp3 = compute_dp(p, prob, InterpolatingAdjoint()) + dp1 = compute_dp(p, prob, ForwardDiffSensitivity()) + dp2 = compute_dp(p, prob, BacksolveAdjoint()) + dp3 = compute_dp(p, prob, InterpolatingAdjoint()) - @show dp1 dp2 dp3 + @show dp1 dp2 dp3 - # different vjp choice - _dp2 = compute_dp(p, prob, BacksolveAdjoint(autojacvec=ReverseDiffVJP())) - @test dp2 ≈ _dp2 rtol=1e-8 - _dp3 = compute_dp(p, prob, InterpolatingAdjoint(autojacvec=ReverseDiffVJP())) - @test dp3 ≈ _dp3 rtol=1e-8 + # different vjp choice + _dp2 = compute_dp(p, prob, BacksolveAdjoint(autojacvec = ReverseDiffVJP())) + @test dp2≈_dp2 rtol=1e-8 + _dp3 = compute_dp(p, prob, InterpolatingAdjoint(autojacvec = ReverseDiffVJP())) + @test dp3≈_dp3 rtol=1e-8 - # mutating - _dp1 = compute_dp(p, prob!, ForwardDiffSensitivity()) - _dp2 = compute_dp(p, prob!, BacksolveAdjoint(autojacvec=ReverseDiffVJP())) - _dp3 = compute_dp(p, prob!, InterpolatingAdjoint(autojacvec=ReverseDiffVJP())) - @test_broken _dp4 = compute_dp(p, prob!, InterpolatingAdjoint()) + # mutating + _dp1 = compute_dp(p, prob!, ForwardDiffSensitivity()) + _dp2 = compute_dp(p, prob!, BacksolveAdjoint(autojacvec = ReverseDiffVJP())) + _dp3 = compute_dp(p, prob!, InterpolatingAdjoint(autojacvec = ReverseDiffVJP())) + @test_broken _dp4 = compute_dp(p, prob!, InterpolatingAdjoint()) - @test dp1 ≈ _dp1 rtol=1e-8 - @test dp2 ≈ _dp2 rtol=1e-8 - @test dp3 ≈ _dp3 rtol=1e-8 - @test_broken dp3 ≈ _dp4 rtol=1e-8 + @test dp1≈_dp1 rtol=1e-8 + @test dp2≈_dp2 rtol=1e-8 + @test dp3≈_dp3 rtol=1e-8 + @test_broken dp3≈_dp4 rtol=1e-8 end diff --git a/test/sde_scalar_ito.jl b/test/sde_scalar_ito.jl index 0c6f64e65..6cf1057bf 100644 --- a/test/sde_scalar_ito.jl +++ b/test/sde_scalar_ito.jl @@ -16,28 +16,28 @@ trange = (tstart, tend) t = tstart:0.01:tend tarray = collect(t) -function g(u,p,t) - sum(u.^2.0/2.0) +function g(u, p, t) + sum(u .^ 2.0 / 2.0) end -function dg!(out,u,p,t,i) - (out.=u) +function dg!(out, u, p, t, i) + (out .= u) end -dt = tend/1e4 +dt = tend / 1e4 # non-exploding initialization. -α = 1/(exp(-randn())+1) -β = -α^2 - 1/(exp(-randn())+1) -p = [α,β] +α = 1 / (exp(-randn()) + 1) +β = -α^2 - 1 / (exp(-randn()) + 1) +p = [α, β] -fIto(u,p,t) = p[1]*u #p[1]*u.+p[2]^2/2*u -fStrat(u,p,t) = p[1]*u.-p[2]^2/2*u #p[1]*u -σ(u,p,t) = p[2]*u +fIto(u, p, t) = p[1] * u #p[1]*u.+p[2]^2/2*u +fStrat(u, p, t) = p[1] * u .- p[2]^2 / 2 * u #p[1]*u +σ(u, p, t) = p[2] * u - # Ito sense (Strat sense for commented version) -linear_analytic(u0,p,t,W) = @.(u0*exp(p[1]*t+p[2]*W)) -corfunc(u,p,t) = p[2]^2*u +# Ito sense (Strat sense for commented version) +linear_analytic(u0, p, t, W) = @.(u0*exp(p[1] * t + p[2] * W)) +corfunc(u, p, t) = p[2]^2 * u """ 1D oop @@ -49,61 +49,63 @@ corfunc(u,p,t) = p[2]^2*u # NG = NoiseGrid(Array(tstart:dt:(tend+dt)),[Z for Z in Z1]) # set initial state -u0 = [1/6] +u0 = [1 / 6] # define problem in Ito sense Random.seed!(seed) probIto = SDEProblem(fIto, - σ,u0,trange,p, - #noise=NG - ) + σ, u0, trange, p + #noise=NG + ) # solve Ito sense -solIto = solve(probIto, EM(), dt=dt, adaptive=false, save_noise=true, saveat=dt) - +solIto = solve(probIto, EM(), dt = dt, adaptive = false, save_noise = true, saveat = dt) # define problem in Stratonovich sense Random.seed!(seed) -probStrat = SDEProblem(SDEFunction(fStrat,σ,), - σ,u0,trange,p, - #noise=NG - ) +probStrat = SDEProblem(SDEFunction(fStrat, σ), + σ, u0, trange, p + #noise=NG + ) # solve Strat sense -solStrat = solve(probStrat,RKMil(interpretation=:Stratonovich), dt=dt, - adaptive=false, save_noise=true, saveat=dt) +solStrat = solve(probStrat, RKMil(interpretation = :Stratonovich), dt = dt, + adaptive = false, save_noise = true, saveat = dt) # check that forward solution agrees -@test isapprox(solIto.u, solStrat.u, rtol=1e-3) -@test isapprox(solIto.u, solStrat.u, atol=1e-2) +@test isapprox(solIto.u, solStrat.u, rtol = 1e-3) +@test isapprox(solIto.u, solStrat.u, atol = 1e-2) #@test isapprox(solIto.u, solIto.u_analytic, rtol=1e-3) - """ solve with continuous adjoint sensitivity tools """ # for Ito sense -gs_u0, gs_p = adjoint_sensitivities(solIto,EM(),t=Array(t),dg_discrete=dg!, - dt=dt,adaptive=false,sensealg=BacksolveAdjoint(),corfunc_analytical=corfunc) +gs_u0, gs_p = adjoint_sensitivities(solIto, EM(), t = Array(t), dg_discrete = dg!, + dt = dt, adaptive = false, + sensealg = BacksolveAdjoint(), + corfunc_analytical = corfunc) @info gs_u0, gs_p -gs_u0a, gs_pa = adjoint_sensitivities(solIto,EM(),t=Array(t),dg_discrete=dg!, - dt=dt,adaptive=false,sensealg=BacksolveAdjoint(autojacvec=SciMLSensitivity.ReverseDiffVJP())) +gs_u0a, gs_pa = adjoint_sensitivities(solIto, EM(), t = Array(t), dg_discrete = dg!, + dt = dt, adaptive = false, + sensealg = BacksolveAdjoint(autojacvec = SciMLSensitivity.ReverseDiffVJP())) @info gs_u0a, gs_pa -@test isapprox(gs_u0, gs_u0a, rtol=1e-8) -@test isapprox(gs_p, gs_pa, rtol=1e-8) +@test isapprox(gs_u0, gs_u0a, rtol = 1e-8) +@test isapprox(gs_p, gs_pa, rtol = 1e-8) # for Strat sense -res_u0, res_p = adjoint_sensitivities(solStrat,EulerHeun(),t=Array(t),dg_discrete=dg!, - dt=dt,adaptive=false,sensealg=BacksolveAdjoint()) +res_u0, res_p = adjoint_sensitivities(solStrat, EulerHeun(), t = Array(t), + dg_discrete = dg!, + dt = dt, adaptive = false, + sensealg = BacksolveAdjoint()) @info res_u0, res_p - """ Tests with respect to analytical result, forward and reverse mode AD """ @@ -111,65 +113,66 @@ Tests with respect to analytical result, forward and reverse mode AD # tests for parameter gradients function Gp(p; sensealg = ReverseDiffAdjoint()) - Random.seed!(seed) - tmp_prob = remake(probStrat,p=p) - _sol = solve(tmp_prob,EulerHeun(),dt=dt,adaptive=false,saveat=Array(t),sensealg=sensealg) - A = convert(Array,_sol) - res = g(A,p,nothing) + Random.seed!(seed) + tmp_prob = remake(probStrat, p = p) + _sol = solve(tmp_prob, EulerHeun(), dt = dt, adaptive = false, saveat = Array(t), + sensealg = sensealg) + A = convert(Array, _sol) + res = g(A, p, nothing) end -res_forward = ForwardDiff.gradient(p -> Gp(p,sensealg=ForwardDiffSensitivity()), p) +res_forward = ForwardDiff.gradient(p -> Gp(p, sensealg = ForwardDiffSensitivity()), p) @info res_forward Wfix = [solStrat.W(t)[1][1] for t in tarray] -resp1 = sum(@. tarray*u0^2*exp(2*(p[1]-p[2]^2/2)*tarray+2*p[2]*Wfix)) -resp2 = sum(@. (Wfix-p[2]*tarray)*u0^2*exp(2*(p[1]-p[2]^2/2)*tarray+2*p[2]*Wfix)) +resp1 = sum(@. tarray * u0^2 * exp(2 * (p[1] - p[2]^2 / 2) * tarray + 2 * p[2] * Wfix)) +resp2 = sum(@. (Wfix - p[2] * tarray) * u0^2 * + exp(2 * (p[1] - p[2]^2 / 2) * tarray + 2 * p[2] * Wfix)) resp = [resp1, resp2] @show resp - -@test isapprox(resp, gs_p', atol=3e-2) # exact vs ito adjoint -@test isapprox(res_p, gs_p, atol=3e-2) # strat vs ito adjoint -@test isapprox(gs_p', res_forward, atol=3e-2) # ito adjoint vs forward -@test isapprox(resp, res_p', rtol=2e-5) # exact vs strat adjoint -@test isapprox(resp, res_forward, rtol=2e-5) # exact vs forward +@test isapprox(resp, gs_p', atol = 3e-2) # exact vs ito adjoint +@test isapprox(res_p, gs_p, atol = 3e-2) # strat vs ito adjoint +@test isapprox(gs_p', res_forward, atol = 3e-2) # ito adjoint vs forward +@test isapprox(resp, res_p', rtol = 2e-5) # exact vs strat adjoint +@test isapprox(resp, res_forward, rtol = 2e-5) # exact vs forward # tests for initial state gradients function Gu0(u0; sensealg = ReverseDiffAdjoint()) - Random.seed!(seed) - tmp_prob = remake(probStrat,u0=u0) - _sol = solve(tmp_prob,EulerHeun(),dt=dt,adaptive=false,saveat=Array(t),sensealg=sensealg) - A = convert(Array,_sol) - res = g(A,p,nothing) + Random.seed!(seed) + tmp_prob = remake(probStrat, u0 = u0) + _sol = solve(tmp_prob, EulerHeun(), dt = dt, adaptive = false, saveat = Array(t), + sensealg = sensealg) + A = convert(Array, _sol) + res = g(A, p, nothing) end -res_forward = ForwardDiff.gradient(u0 -> Gu0(u0,sensealg=ForwardDiffSensitivity()), u0) +res_forward = ForwardDiff.gradient(u0 -> Gu0(u0, sensealg = ForwardDiffSensitivity()), u0) -resu0 = sum(@. u0*exp(2*(p[1]-p[2]^2/2)*tarray+2*p[2]*Wfix)) +resu0 = sum(@. u0 * exp(2 * (p[1] - p[2]^2 / 2) * tarray + 2 * p[2] * Wfix)) @show resu0 -@test isapprox(resu0, gs_u0[1], rtol=5e-2) # exact vs ito adjoint -@test isapprox(res_u0, gs_u0, rtol=5e-2) # strat vs ito adjoint -@test isapprox(gs_u0, res_forward, rtol=5e-2) # ito adjoint vs forward -@test isapprox(resu0, res_u0[1], rtol=1e-3) # exact vs strat adjoint -@test isapprox(res_u0, res_forward, rtol=1e-3) # strat adjoint vs forward -@test isapprox(resu0, res_forward[1], rtol=1e-3) # exact vs forward - +@test isapprox(resu0, gs_u0[1], rtol = 5e-2) # exact vs ito adjoint +@test isapprox(res_u0, gs_u0, rtol = 5e-2) # strat vs ito adjoint +@test isapprox(gs_u0, res_forward, rtol = 5e-2) # ito adjoint vs forward +@test isapprox(resu0, res_u0[1], rtol = 1e-3) # exact vs strat adjoint +@test isapprox(res_u0, res_forward, rtol = 1e-3) # strat adjoint vs forward +@test isapprox(resu0, res_forward[1], rtol = 1e-3) # exact vs forward -adj_probStrat = SDEAdjointProblem(solStrat,BacksolveAdjoint(autojacvec=ZygoteVJP()),t,dg!) -adj_solStrat = solve(adj_probStrat,EulerHeun(), dt=dt) +adj_probStrat = SDEAdjointProblem(solStrat, BacksolveAdjoint(autojacvec = ZygoteVJP()), t, + dg!) +adj_solStrat = solve(adj_probStrat, EulerHeun(), dt = dt) #@show adj_solStrat[end] -adj_probIto = SDEAdjointProblem(solIto,BacksolveAdjoint(autojacvec=ZygoteVJP()),t,dg!, - corfunc_analytical=corfunc) -adj_solIto = solve(adj_probIto,EM(), dt=dt) - -@test isapprox(adj_solStrat[4,:], adj_solIto[4,:], rtol=1e-3) +adj_probIto = SDEAdjointProblem(solIto, BacksolveAdjoint(autojacvec = ZygoteVJP()), t, dg!, + corfunc_analytical = corfunc) +adj_solIto = solve(adj_probIto, EM(), dt = dt) +@test isapprox(adj_solStrat[4, :], adj_solIto[4, :], rtol = 1e-3) # using Plots # pl1 = plot(solStrat, label="Strat forward") diff --git a/test/sde_scalar_stratonovich.jl b/test/sde_scalar_stratonovich.jl index 67481467f..ca06c6fb9 100644 --- a/test/sde_scalar_stratonovich.jl +++ b/test/sde_scalar_stratonovich.jl @@ -14,201 +14,219 @@ trange = (tstart, tend) t = tstart:dt:tend tarray = collect(t) -function g(u,p,t) - sum(u.^2.0/2.0) +function g(u, p, t) + sum(u .^ 2.0 / 2.0) end -function dg!(out,u,p,t,i) - (out.=u) +function dg!(out, u, p, t, i) + (out .= u) end -p2 = [1.01,0.87] - +p2 = [1.01, 0.87] # scalar noise @testset "SDE inplace scalar noise tests" begin - using DiffEqNoiseProcess - - dtscalar = tend/1e3 - - f!(du,u,p,t) = (du .= p[1]*u) - σ!(du,u,p,t) = (du .= p[2]*u) + using DiffEqNoiseProcess - @info "scalar SDE" + dtscalar = tend / 1e3 - Random.seed!(seed) - W = WienerProcess(0.0,0.0,0.0) - u0 = rand(2) + f!(du, u, p, t) = (du .= p[1] * u) + σ!(du, u, p, t) = (du .= p[2] * u) - linear_analytic_strat(u0,p,t,W) = @.(u0*exp(p[1]*t+p[2]*W)) + @info "scalar SDE" - prob = SDEProblem(SDEFunction(f!,σ!,analytic=linear_analytic_strat),σ!,u0,trange,p2, - noise=W - ) - sol = solve(prob,EulerHeun(), dt=dtscalar, save_noise=true) + Random.seed!(seed) + W = WienerProcess(0.0, 0.0, 0.0) + u0 = rand(2) - @test isapprox(sol.u_analytic,sol.u, atol=1e-4) + linear_analytic_strat(u0, p, t, W) = @.(u0*exp(p[1] * t + p[2] * W)) - res_sde_u0, res_sde_p = adjoint_sensitivities(sol,EulerHeun(),t=Array(t),dg_discrete=dg!, - dt=dtscalar,adaptive=false,sensealg=BacksolveAdjoint()) + prob = SDEProblem(SDEFunction(f!, σ!, analytic = linear_analytic_strat), σ!, u0, trange, + p2, + noise = W) + sol = solve(prob, EulerHeun(), dt = dtscalar, save_noise = true) - @show res_sde_u0, res_sde_p + @test isapprox(sol.u_analytic, sol.u, atol = 1e-4) - res_sde_u02, res_sde_p2 = adjoint_sensitivities(sol,EulerHeun(),t=Array(t),dg_discrete=dg!, - dt=dtscalar,adaptive=false,sensealg=BacksolveAdjoint(autojacvec=false)) + res_sde_u0, res_sde_p = adjoint_sensitivities(sol, EulerHeun(), t = Array(t), + dg_discrete = dg!, + dt = dtscalar, adaptive = false, + sensealg = BacksolveAdjoint()) - @test isapprox(res_sde_u0, res_sde_u02, atol=1e-8) - @test isapprox(res_sde_p, res_sde_p2, atol=1e-8) + @show res_sde_u0, res_sde_p - res_sde_u02, res_sde_p2 = adjoint_sensitivities(sol,EulerHeun(),t=Array(t),dg_discrete=dg!, - dt=dtscalar,adaptive=false,sensealg=BacksolveAdjoint(autojacvec=ReverseDiffVJP())) + res_sde_u02, res_sde_p2 = adjoint_sensitivities(sol, EulerHeun(), t = Array(t), + dg_discrete = dg!, + dt = dtscalar, adaptive = false, + sensealg = BacksolveAdjoint(autojacvec = false)) - @test isapprox(res_sde_u0, res_sde_u02, atol=1e-8) - @test isapprox(res_sde_p, res_sde_p2, atol=1e-8) + @test isapprox(res_sde_u0, res_sde_u02, atol = 1e-8) + @test isapprox(res_sde_p, res_sde_p2, atol = 1e-8) - @show res_sde_u02, res_sde_p2 + res_sde_u02, res_sde_p2 = adjoint_sensitivities(sol, EulerHeun(), t = Array(t), + dg_discrete = dg!, + dt = dtscalar, adaptive = false, + sensealg = BacksolveAdjoint(autojacvec = ReverseDiffVJP())) + @test isapprox(res_sde_u0, res_sde_u02, atol = 1e-8) + @test isapprox(res_sde_p, res_sde_p2, atol = 1e-8) - res_sde_u02, res_sde_p2 = adjoint_sensitivities(sol,EulerHeun(),t=Array(t),dg_discrete=dg!, - dt=tend/1e2,adaptive=false,sensealg=InterpolatingAdjoint()) + @show res_sde_u02, res_sde_p2 + res_sde_u02, res_sde_p2 = adjoint_sensitivities(sol, EulerHeun(), t = Array(t), + dg_discrete = dg!, + dt = tend / 1e2, adaptive = false, + sensealg = InterpolatingAdjoint()) - @test isapprox(res_sde_u0, res_sde_u02, rtol=1e-4) - @test isapprox(res_sde_p, res_sde_p2, rtol=1e-4) + @test isapprox(res_sde_u0, res_sde_u02, rtol = 1e-4) + @test isapprox(res_sde_p, res_sde_p2, rtol = 1e-4) - @show res_sde_u02, res_sde_p2 + @show res_sde_u02, res_sde_p2 - res_sde_u02, res_sde_p2 = adjoint_sensitivities(sol,EulerHeun(),t=Array(t),dg_discrete=dg!, - dt=dtscalar,adaptive=false,sensealg=InterpolatingAdjoint(autojacvec=false)) + res_sde_u02, res_sde_p2 = adjoint_sensitivities(sol, EulerHeun(), t = Array(t), + dg_discrete = dg!, + dt = dtscalar, adaptive = false, + sensealg = InterpolatingAdjoint(autojacvec = false)) - @test isapprox(res_sde_u0, res_sde_u02, rtol=1e-4) - @test isapprox(res_sde_p, res_sde_p2, rtol=1e-4) + @test isapprox(res_sde_u0, res_sde_u02, rtol = 1e-4) + @test isapprox(res_sde_p, res_sde_p2, rtol = 1e-4) - @show res_sde_u02, res_sde_p2 + @show res_sde_u02, res_sde_p2 - res_sde_u02, res_sde_p2 = adjoint_sensitivities(sol,EulerHeun(),t=Array(t),dg_discrete=dg!, - dt=dtscalar,adaptive=false,sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP())) + res_sde_u02, res_sde_p2 = adjoint_sensitivities(sol, EulerHeun(), t = Array(t), + dg_discrete = dg!, + dt = dtscalar, adaptive = false, + sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP())) - @test isapprox(res_sde_u0, res_sde_u02, rtol=1e-4) - @test isapprox(res_sde_p, res_sde_p2, rtol=1e-4) + @test isapprox(res_sde_u0, res_sde_u02, rtol = 1e-4) + @test isapprox(res_sde_p, res_sde_p2, rtol = 1e-4) - @show res_sde_u02, res_sde_p2 + @show res_sde_u02, res_sde_p2 - function compute_grads(sol, scale=1.0) - _sol = deepcopy(sol) - _sol.W.save_everystep = false - xdis = _sol(tarray) - helpu1 = [u[1] for u in xdis.u] - tmp1 = sum((@. xdis.t*helpu1*helpu1)) + function compute_grads(sol, scale = 1.0) + _sol = deepcopy(sol) + _sol.W.save_everystep = false + xdis = _sol(tarray) + helpu1 = [u[1] for u in xdis.u] + tmp1 = sum((@. xdis.t * helpu1 * helpu1)) - Wtmp = [_sol.W(t)[1][1] for t in tarray] - tmp2 = sum((@. Wtmp*helpu1*helpu1)) + Wtmp = [_sol.W(t)[1][1] for t in tarray] + tmp2 = sum((@. Wtmp * helpu1 * helpu1)) - tmp3 = sum((@. helpu1*helpu1))/helpu1[1] + tmp3 = sum((@. helpu1 * helpu1)) / helpu1[1] - return [tmp3, scale*tmp3], [tmp1*(1.0+scale^2), tmp2*(1.0+scale^2)] - end + return [tmp3, scale * tmp3], [tmp1 * (1.0 + scale^2), tmp2 * (1.0 + scale^2)] + end - true_grads = compute_grads(sol, u0[2]/u0[1]) + true_grads = compute_grads(sol, u0[2] / u0[1]) - @show true_grads + @show true_grads - @test isapprox(res_sde_u0, res_sde_u02, rtol=1e-4) - @test isapprox(res_sde_p, res_sde_p2, rtol=1e-4) - @test isapprox(true_grads[2], res_sde_p', atol=1e-4) - @test isapprox(true_grads[1], res_sde_u0, rtol=1e-4) - @test isapprox(true_grads[2], res_sde_p2', atol=1e-4) - @test isapprox(true_grads[1], res_sde_u02, rtol=1e-4) + @test isapprox(res_sde_u0, res_sde_u02, rtol = 1e-4) + @test isapprox(res_sde_p, res_sde_p2, rtol = 1e-4) + @test isapprox(true_grads[2], res_sde_p', atol = 1e-4) + @test isapprox(true_grads[1], res_sde_u0, rtol = 1e-4) + @test isapprox(true_grads[2], res_sde_p2', atol = 1e-4) + @test isapprox(true_grads[1], res_sde_u02, rtol = 1e-4) end @testset "SDE oop scalar noise tests" begin - using DiffEqNoiseProcess - - dtscalar = tend/1e3 - - f(u,p,t) = p[1]*u - σ(u,p,t) = p[2]*u - - Random.seed!(seed) - W = WienerProcess(0.0,0.0,0.0) - u0 = rand(2) + using DiffEqNoiseProcess - linear_analytic_strat(u0,p,t,W) = @.(u0*exp(p[1]*t+p[2]*W)) + dtscalar = tend / 1e3 - prob = SDEProblem(SDEFunction(f,σ,analytic=linear_analytic_strat),σ,u0,trange,p2, - noise=W - ) - sol = solve(prob,EulerHeun(), dt=dtscalar, save_noise=true) + f(u, p, t) = p[1] * u + σ(u, p, t) = p[2] * u - @test isapprox(sol.u_analytic,sol.u, atol=1e-4) + Random.seed!(seed) + W = WienerProcess(0.0, 0.0, 0.0) + u0 = rand(2) - res_sde_u0, res_sde_p = adjoint_sensitivities(sol,EulerHeun(),t=Array(t),dg_discrete=dg!, - dt=dtscalar,adaptive=false,sensealg=BacksolveAdjoint()) + linear_analytic_strat(u0, p, t, W) = @.(u0*exp(p[1] * t + p[2] * W)) - @show res_sde_u0, res_sde_p + prob = SDEProblem(SDEFunction(f, σ, analytic = linear_analytic_strat), σ, u0, trange, + p2, + noise = W) + sol = solve(prob, EulerHeun(), dt = dtscalar, save_noise = true) - res_sde_u02, res_sde_p2 = adjoint_sensitivities(sol,EulerHeun(),t=Array(t),dg_discrete=dg!, - dt=dtscalar,adaptive=false,sensealg=BacksolveAdjoint(autojacvec=false)) + @test isapprox(sol.u_analytic, sol.u, atol = 1e-4) - @test isapprox(res_sde_u0, res_sde_u02, atol=1e-8) - @test isapprox(res_sde_p, res_sde_p2, atol=1e-8) + res_sde_u0, res_sde_p = adjoint_sensitivities(sol, EulerHeun(), t = Array(t), + dg_discrete = dg!, + dt = dtscalar, adaptive = false, + sensealg = BacksolveAdjoint()) - res_sde_u02, res_sde_p2 = adjoint_sensitivities(sol,EulerHeun(),t=Array(t),dg_discrete=dg!, - dt=dtscalar,adaptive=false,sensealg=BacksolveAdjoint(autojacvec=ReverseDiffVJP())) + @show res_sde_u0, res_sde_p - @test isapprox(res_sde_u0, res_sde_u02, atol=1e-8) - @test isapprox(res_sde_p, res_sde_p2, atol=1e-8) + res_sde_u02, res_sde_p2 = adjoint_sensitivities(sol, EulerHeun(), t = Array(t), + dg_discrete = dg!, + dt = dtscalar, adaptive = false, + sensealg = BacksolveAdjoint(autojacvec = false)) - @show res_sde_u02, res_sde_p2 + @test isapprox(res_sde_u0, res_sde_u02, atol = 1e-8) + @test isapprox(res_sde_p, res_sde_p2, atol = 1e-8) - res_sde_u02, res_sde_p2 = adjoint_sensitivities(sol,EulerHeun(),t=Array(t),dg_discrete=dg!, - dt=tend/1e2,adaptive=false,sensealg=InterpolatingAdjoint()) + res_sde_u02, res_sde_p2 = adjoint_sensitivities(sol, EulerHeun(), t = Array(t), + dg_discrete = dg!, + dt = dtscalar, adaptive = false, + sensealg = BacksolveAdjoint(autojacvec = ReverseDiffVJP())) + @test isapprox(res_sde_u0, res_sde_u02, atol = 1e-8) + @test isapprox(res_sde_p, res_sde_p2, atol = 1e-8) - @test isapprox(res_sde_u0, res_sde_u02, rtol=1e-4) - @test isapprox(res_sde_p, res_sde_p2, atol=1e-4) + @show res_sde_u02, res_sde_p2 - @show res_sde_u02, res_sde_p2 + res_sde_u02, res_sde_p2 = adjoint_sensitivities(sol, EulerHeun(), t = Array(t), + dg_discrete = dg!, + dt = tend / 1e2, adaptive = false, + sensealg = InterpolatingAdjoint()) - res_sde_u02, res_sde_p2 = adjoint_sensitivities(sol,EulerHeun(),t=Array(t),dg_discrete=dg!, - dt=dtscalar,adaptive=false,sensealg=InterpolatingAdjoint(autojacvec=false)) + @test isapprox(res_sde_u0, res_sde_u02, rtol = 1e-4) + @test isapprox(res_sde_p, res_sde_p2, atol = 1e-4) - @test isapprox(res_sde_u0, res_sde_u02, rtol=1e-4) - @test isapprox(res_sde_p, res_sde_p2, atol=1e-4) + @show res_sde_u02, res_sde_p2 - @show res_sde_u02, res_sde_p2 + res_sde_u02, res_sde_p2 = adjoint_sensitivities(sol, EulerHeun(), t = Array(t), + dg_discrete = dg!, + dt = dtscalar, adaptive = false, + sensealg = InterpolatingAdjoint(autojacvec = false)) - res_sde_u02, res_sde_p2 = adjoint_sensitivities(sol,EulerHeun(),t=Array(t),dg_discrete=dg!, - dt=dtscalar,adaptive=false,sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP())) + @test isapprox(res_sde_u0, res_sde_u02, rtol = 1e-4) + @test isapprox(res_sde_p, res_sde_p2, atol = 1e-4) - @test isapprox(res_sde_u0, res_sde_u02, rtol=1e-4) - @test isapprox(res_sde_p, res_sde_p2, atol=1e-4) + @show res_sde_u02, res_sde_p2 - @show res_sde_u02, res_sde_p2 + res_sde_u02, res_sde_p2 = adjoint_sensitivities(sol, EulerHeun(), t = Array(t), + dg_discrete = dg!, + dt = dtscalar, adaptive = false, + sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP())) - function compute_grads(sol, scale=1.0) - _sol = deepcopy(sol) - _sol.W.save_everystep = false - xdis = _sol(tarray) - helpu1 = [u[1] for u in xdis.u] - tmp1 = sum((@. xdis.t*helpu1*helpu1)) + @test isapprox(res_sde_u0, res_sde_u02, rtol = 1e-4) + @test isapprox(res_sde_p, res_sde_p2, atol = 1e-4) - Wtmp = [_sol.W(t)[1][1] for t in tarray] - tmp2 = sum((@. Wtmp*helpu1*helpu1)) + @show res_sde_u02, res_sde_p2 - tmp3 = sum((@. helpu1*helpu1))/helpu1[1] + function compute_grads(sol, scale = 1.0) + _sol = deepcopy(sol) + _sol.W.save_everystep = false + xdis = _sol(tarray) + helpu1 = [u[1] for u in xdis.u] + tmp1 = sum((@. xdis.t * helpu1 * helpu1)) - return [tmp3, scale*tmp3], [tmp1*(1.0+scale^2), tmp2*(1.0+scale^2)] - end + Wtmp = [_sol.W(t)[1][1] for t in tarray] + tmp2 = sum((@. Wtmp * helpu1 * helpu1)) - true_grads = compute_grads(sol, u0[2]/u0[1]) + tmp3 = sum((@. helpu1 * helpu1)) / helpu1[1] - @show true_grads + return [tmp3, scale * tmp3], [tmp1 * (1.0 + scale^2), tmp2 * (1.0 + scale^2)] + end + true_grads = compute_grads(sol, u0[2] / u0[1]) - @test isapprox(true_grads[2], res_sde_p', atol=1e-4) - @test isapprox(true_grads[1], res_sde_u0, rtol=1e-4) - @test isapprox(true_grads[2], res_sde_p2', atol=1e-4) - @test isapprox(true_grads[1], res_sde_u02, rtol=1e-4) + @show true_grads + @test isapprox(true_grads[2], res_sde_p', atol = 1e-4) + @test isapprox(true_grads[1], res_sde_u0, rtol = 1e-4) + @test isapprox(true_grads[2], res_sde_p2', atol = 1e-4) + @test isapprox(true_grads[1], res_sde_u02, rtol = 1e-4) end diff --git a/test/sde_stratonovich.jl b/test/sde_stratonovich.jl index 60b9823be..7b07c3f39 100644 --- a/test/sde_stratonovich.jl +++ b/test/sde_stratonovich.jl @@ -20,465 +20,539 @@ trange = (tstart, tend) t = tstart:dt:tend tarray = collect(t) -function g(u,p,t) - sum(u.^2.0/2.0) +function g(u, p, t) + sum(u .^ 2.0 / 2.0) end -function dg!(out,u,p,t,i) - (out.=u) +function dg!(out, u, p, t, i) + (out .= u) end -p2 = [1.01,0.87] - +p2 = [1.01, 0.87] @testset "SDE oop Tests (no noise)" begin + f_oop_linear(u, p, t) = p[1] * u + σ_oop_linear(u, p, t) = p[2] * u - f_oop_linear(u,p,t) = p[1]*u - σ_oop_linear(u,p,t) = p[2]*u - - p = [1.01,0.0] - - # generate ODE adjoint results - - prob_oop_ode = ODEProblem(f_oop_linear,u₀,(tstart,tend),p) - sol_oop_ode = solve(prob_oop_ode,Tsit5(),saveat=t,abstol=abstol,reltol=reltol) - res_ode_u0, res_ode_p = adjoint_sensitivities(sol_oop_ode,Tsit5(),t=t,dg_discrete=dg! - ,abstol=abstol,reltol=reltol,sensealg=BacksolveAdjoint()) - - function G(p) - tmp_prob = remake(prob_oop_ode,u0=eltype(p).(prob_oop_ode.u0),p=p, - tspan=eltype(p).(prob_oop_ode.tspan),abstol=abstol, reltol=reltol) - sol = solve(tmp_prob,Tsit5(),saveat=tarray,abstol=abstol, reltol=reltol) - res = g(sol,p,nothing) - end - res_ode_forward = ForwardDiff.gradient(G,p) - - @test isapprox(res_ode_forward[1], sum(@. u₀^2*exp(2*p[1]*t)*t), rtol = 1e-4) - #@test isapprox(res_ode_reverse[1], sum(@. u₀^2*exp(2*p[1]*t)*t), rtol = 1e-4) - @test isapprox(res_ode_p'[1], sum(@. u₀^2*exp(2*p[1]*t)*t), rtol = 1e-4) - #@test isapprox(res_ode_p', res_ode_trackerp, rtol = 1e-4) + p = [1.01, 0.0] - # SDE adjoint results (with noise == 0, so should agree with above) + # generate ODE adjoint results - Random.seed!(seed) - prob_oop_sde = SDEProblem(f_oop_linear,σ_oop_linear,u₀,trange,p) - sol_oop_sde = solve(prob_oop_sde,EulerHeun(),dt=1e-4,adaptive=false,save_noise=true) - res_sde_u0, res_sde_p = adjoint_sensitivities(sol_oop_sde, - EulerHeun(),t=t,dg_discrete=dg!,dt=1e-2,sensealg=BacksolveAdjoint()) + prob_oop_ode = ODEProblem(f_oop_linear, u₀, (tstart, tend), p) + sol_oop_ode = solve(prob_oop_ode, Tsit5(), saveat = t, abstol = abstol, reltol = reltol) + res_ode_u0, res_ode_p = adjoint_sensitivities(sol_oop_ode, Tsit5(), t = t, + dg_discrete = dg!, abstol = abstol, + reltol = reltol, + sensealg = BacksolveAdjoint()) - @info res_sde_p + function G(p) + tmp_prob = remake(prob_oop_ode, u0 = eltype(p).(prob_oop_ode.u0), p = p, + tspan = eltype(p).(prob_oop_ode.tspan), abstol = abstol, + reltol = reltol) + sol = solve(tmp_prob, Tsit5(), saveat = tarray, abstol = abstol, reltol = reltol) + res = g(sol, p, nothing) + end + res_ode_forward = ForwardDiff.gradient(G, p) - res_sde_u0a, res_sde_pa = adjoint_sensitivities(sol_oop_sde, - EulerHeun(),t=t,dg_discrete=dg!,dt=1e-2,sensealg=InterpolatingAdjoint()) + @test isapprox(res_ode_forward[1], sum(@. u₀^2 * exp(2 * p[1] * t) * t), rtol = 1e-4) + #@test isapprox(res_ode_reverse[1], sum(@. u₀^2*exp(2*p[1]*t)*t), rtol = 1e-4) + @test isapprox(res_ode_p'[1], sum(@. u₀^2 * exp(2 * p[1] * t) * t), rtol = 1e-4) + #@test isapprox(res_ode_p', res_ode_trackerp, rtol = 1e-4) - @test isapprox(res_sde_u0, res_sde_u0a, rtol = 1e-6) - @test isapprox(res_sde_p, res_sde_pa, rtol = 1e-6) + # SDE adjoint results (with noise == 0, so should agree with above) - function GSDE1(p) Random.seed!(seed) - tmp_prob = remake(prob_oop_sde,u0=eltype(p).(prob_oop_sde.u0),p=p, - tspan=eltype(p).(prob_oop_sde.tspan)) - sol = solve(tmp_prob,RKMil(interpretation=:Stratonovich),dt=tend/10000,adaptive=false,sensealg=DiffEqBase.SensitivityADPassThrough(),saveat=tarray) - A = convert(Array,sol) - res = g(A,p,nothing) - end - res_sde_forward = ForwardDiff.gradient(GSDE1,p) - - noise = vec((@. sol_oop_sde.W(tarray))) - Wfix = [W[1][1] for W in noise] - @test isapprox(res_sde_forward[1], sum(@. u₀^2*exp(2*p[1]*t)*t), rtol = 1e-4) - @test isapprox(res_sde_p'[1], sum(@. u₀^2*exp(2*p[1]*t)*t), rtol = 1e-4) - @test isapprox(res_sde_p'[2], sum(@. (Wfix)*u₀^2*exp(2*(p[1])*tarray+2*p[2]*Wfix)), rtol = 1e-4) + prob_oop_sde = SDEProblem(f_oop_linear, σ_oop_linear, u₀, trange, p) + sol_oop_sde = solve(prob_oop_sde, EulerHeun(), dt = 1e-4, adaptive = false, + save_noise = true) + res_sde_u0, res_sde_p = adjoint_sensitivities(sol_oop_sde, + EulerHeun(), t = t, dg_discrete = dg!, + dt = 1e-2, sensealg = BacksolveAdjoint()) + + @info res_sde_p + + res_sde_u0a, res_sde_pa = adjoint_sensitivities(sol_oop_sde, + EulerHeun(), t = t, dg_discrete = dg!, + dt = 1e-2, + sensealg = InterpolatingAdjoint()) + + @test isapprox(res_sde_u0, res_sde_u0a, rtol = 1e-6) + @test isapprox(res_sde_p, res_sde_pa, rtol = 1e-6) + + function GSDE1(p) + Random.seed!(seed) + tmp_prob = remake(prob_oop_sde, u0 = eltype(p).(prob_oop_sde.u0), p = p, + tspan = eltype(p).(prob_oop_sde.tspan)) + sol = solve(tmp_prob, RKMil(interpretation = :Stratonovich), dt = tend / 10000, + adaptive = false, sensealg = DiffEqBase.SensitivityADPassThrough(), + saveat = tarray) + A = convert(Array, sol) + res = g(A, p, nothing) + end + res_sde_forward = ForwardDiff.gradient(GSDE1, p) + + noise = vec((@. sol_oop_sde.W(tarray))) + Wfix = [W[1][1] for W in noise] + @test isapprox(res_sde_forward[1], sum(@. u₀^2 * exp(2 * p[1] * t) * t), rtol = 1e-4) + @test isapprox(res_sde_p'[1], sum(@. u₀^2 * exp(2 * p[1] * t) * t), rtol = 1e-4) + @test isapprox(res_sde_p'[2], + sum(@. (Wfix) * u₀^2 * exp(2 * (p[1]) * tarray + 2 * p[2] * Wfix)), + rtol = 1e-4) end @testset "SDE oop Tests (with noise)" begin + f_oop_linear(u, p, t) = p[1] * u + σ_oop_linear(u, p, t) = p[2] * u - f_oop_linear(u,p,t) = p[1]*u - σ_oop_linear(u,p,t) = p[2]*u - - # SDE adjoint results (with noise != 0) - dt1 = tend/1e3 - - Random.seed!(seed) - prob_oop_sde2 = SDEProblem(f_oop_linear,σ_oop_linear,u₀,trange,p2) - sol_oop_sde2 = solve(prob_oop_sde2,RKMil(interpretation=:Stratonovich),dt=dt1,adaptive=false,save_noise=true) - - res_sde_u02, res_sde_p2 = adjoint_sensitivities(sol_oop_sde2,EulerHeun(),t=tarray,dg_discrete=dg!, - dt=dt1,adaptive=false,sensealg=BacksolveAdjoint()) - - @info res_sde_p2 - - # test consitency for different switches for the noise Jacobian - res_sde_u02a, res_sde_p2a = adjoint_sensitivities(sol_oop_sde2,EulerHeun(),t=tarray,dg_discrete=dg!, - dt=dt1,adaptive=false,sensealg=BacksolveAdjoint(autojacvec=false)) - - @test isapprox(res_sde_u02, res_sde_u02a, rtol = 1e-6) - @test isapprox(res_sde_p2, res_sde_p2a, rtol = 1e-6) + # SDE adjoint results (with noise != 0) + dt1 = tend / 1e3 - @info res_sde_p2a - - res_sde_u02a, res_sde_p2a = adjoint_sensitivities(sol_oop_sde2,EulerHeun(),t=tarray,dg_discrete=dg!, - dt=dt1,adaptive=false,sensealg=BacksolveAdjoint(autojacvec=ZygoteVJP())) - - @test isapprox(res_sde_u02, res_sde_u02a, rtol = 1e-6) - @test isapprox(res_sde_p2, res_sde_p2a, rtol = 1e-6) - - @info res_sde_p2a - - res_sde_u02a, res_sde_p2a = adjoint_sensitivities(sol_oop_sde2,EulerHeun(),t=tarray,dg_discrete=dg!, - dt=tend/dt1,adaptive=false,sensealg=BacksolveAdjoint(autojacvec=ReverseDiffVJP())) - - @test isapprox(res_sde_u02, res_sde_u02a, rtol = 1e-6) - @test isapprox(res_sde_p2, res_sde_p2a, rtol = 1e-6) - - @info res_sde_p2a - - function GSDE2(p) Random.seed!(seed) - tmp_prob = remake(prob_oop_sde2,u0=eltype(p).(prob_oop_sde2.u0),p=p, - tspan=eltype(p).(prob_oop_sde2.tspan) - #,abstol=abstol, reltol=reltol - ) - sol = solve(tmp_prob,RKMil(interpretation=:Stratonovich),dt=dt1,adaptive=false,sensealg=DiffEqBase.SensitivityADPassThrough(),saveat=tarray) - A = convert(Array,sol) - res = g(A,p,nothing) - end - res_sde_forward2 = ForwardDiff.gradient(GSDE2,p2) - - - Wfix = [sol_oop_sde2.W(t)[1][1] for t in tarray] - resp1 = sum(@. tarray*u₀^2*exp(2*(p2[1])*tarray+2*p2[2]*Wfix)) - resp2 = sum(@. (Wfix)*u₀^2*exp(2*(p2[1])*tarray+2*p2[2]*Wfix)) - resp = [resp1, resp2] - - @test isapprox(res_sde_forward2, resp, rtol = 8e-4) - - @test isapprox(res_sde_p2', res_sde_forward2, rtol = 1e-3) - @test isapprox(res_sde_p2', resp, rtol = 1e-3) - - @info "ForwardDiff" res_sde_forward2 - @info "Exact" resp - @info "BacksolveAdjoint SDE" res_sde_p2 + prob_oop_sde2 = SDEProblem(f_oop_linear, σ_oop_linear, u₀, trange, p2) + sol_oop_sde2 = solve(prob_oop_sde2, RKMil(interpretation = :Stratonovich), dt = dt1, + adaptive = false, save_noise = true) - # InterpolatingAdjoint - @info "InterpolatingAdjoint SDE" + res_sde_u02, res_sde_p2 = adjoint_sensitivities(sol_oop_sde2, EulerHeun(), t = tarray, + dg_discrete = dg!, + dt = dt1, adaptive = false, + sensealg = BacksolveAdjoint()) - res_sde_u02a, res_sde_p2a = adjoint_sensitivities(sol_oop_sde2,EulerHeun(),t=tarray,dg_discrete=dg!, - dt=dt1,adaptive=false,sensealg=InterpolatingAdjoint()) + @info res_sde_p2 - @test isapprox(res_sde_u02, res_sde_u02a, rtol = 1e-4) - @test isapprox(res_sde_p2, res_sde_p2a, rtol = 1e-3) + # test consitency for different switches for the noise Jacobian + res_sde_u02a, res_sde_p2a = adjoint_sensitivities(sol_oop_sde2, EulerHeun(), t = tarray, + dg_discrete = dg!, + dt = dt1, adaptive = false, + sensealg = BacksolveAdjoint(autojacvec = false)) + + @test isapprox(res_sde_u02, res_sde_u02a, rtol = 1e-6) + @test isapprox(res_sde_p2, res_sde_p2a, rtol = 1e-6) + + @info res_sde_p2a + + res_sde_u02a, res_sde_p2a = adjoint_sensitivities(sol_oop_sde2, EulerHeun(), t = tarray, + dg_discrete = dg!, + dt = dt1, adaptive = false, + sensealg = BacksolveAdjoint(autojacvec = ZygoteVJP())) + + @test isapprox(res_sde_u02, res_sde_u02a, rtol = 1e-6) + @test isapprox(res_sde_p2, res_sde_p2a, rtol = 1e-6) + + @info res_sde_p2a + + res_sde_u02a, res_sde_p2a = adjoint_sensitivities(sol_oop_sde2, EulerHeun(), t = tarray, + dg_discrete = dg!, + dt = tend / dt1, adaptive = false, + sensealg = BacksolveAdjoint(autojacvec = ReverseDiffVJP())) + + @test isapprox(res_sde_u02, res_sde_u02a, rtol = 1e-6) + @test isapprox(res_sde_p2, res_sde_p2a, rtol = 1e-6) + + @info res_sde_p2a + + function GSDE2(p) + Random.seed!(seed) + tmp_prob = remake(prob_oop_sde2, u0 = eltype(p).(prob_oop_sde2.u0), p = p, + tspan = eltype(p).(prob_oop_sde2.tspan) + #,abstol=abstol, reltol=reltol + ) + sol = solve(tmp_prob, RKMil(interpretation = :Stratonovich), dt = dt1, + adaptive = false, sensealg = DiffEqBase.SensitivityADPassThrough(), + saveat = tarray) + A = convert(Array, sol) + res = g(A, p, nothing) + end + res_sde_forward2 = ForwardDiff.gradient(GSDE2, p2) + + Wfix = [sol_oop_sde2.W(t)[1][1] for t in tarray] + resp1 = sum(@. tarray * u₀^2 * exp(2 * (p2[1]) * tarray + 2 * p2[2] * Wfix)) + resp2 = sum(@. (Wfix) * u₀^2 * exp(2 * (p2[1]) * tarray + 2 * p2[2] * Wfix)) + resp = [resp1, resp2] + + @test isapprox(res_sde_forward2, resp, rtol = 8e-4) + + @test isapprox(res_sde_p2', res_sde_forward2, rtol = 1e-3) + @test isapprox(res_sde_p2', resp, rtol = 1e-3) + + @info "ForwardDiff" res_sde_forward2 + @info "Exact" resp + @info "BacksolveAdjoint SDE" res_sde_p2 + + # InterpolatingAdjoint + @info "InterpolatingAdjoint SDE" + + res_sde_u02a, res_sde_p2a = adjoint_sensitivities(sol_oop_sde2, EulerHeun(), t = tarray, + dg_discrete = dg!, + dt = dt1, adaptive = false, + sensealg = InterpolatingAdjoint()) + + @test isapprox(res_sde_u02, res_sde_u02a, rtol = 1e-4) + @test isapprox(res_sde_p2, res_sde_p2a, rtol = 1e-3) + + @info res_sde_p2a + + res_sde_u02, res_sde_p2 = adjoint_sensitivities(sol_oop_sde2, EulerHeun(), t = tarray, + dg_discrete = dg!, + dt = dt1, adaptive = false, + sensealg = InterpolatingAdjoint(autojacvec = false)) + + @test isapprox(res_sde_u02, res_sde_u02a, rtol = 1e-6) + @test isapprox(res_sde_p2, res_sde_p2a, rtol = 1e-6) + + res_sde_u02, res_sde_p2 = adjoint_sensitivities(sol_oop_sde2, EulerHeun(), t = tarray, + dg_discrete = dg!, + dt = dt1, adaptive = false, + sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP())) + + @test isapprox(res_sde_u02, res_sde_u02a, rtol = 1e-6) + @test isapprox(res_sde_p2, res_sde_p2a, rtol = 1e-6) + + res_sde_u02, res_sde_p2 = adjoint_sensitivities(sol_oop_sde2, EulerHeun(), t = tarray, + dg_discrete = dg!, + dt = dt1, adaptive = false, + sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP())) - @info res_sde_p2a + @test isapprox(res_sde_u02, res_sde_u02a, rtol = 1e-6) + @test isapprox(res_sde_p2, res_sde_p2a, rtol = 1e-6) - res_sde_u02, res_sde_p2 = adjoint_sensitivities(sol_oop_sde2,EulerHeun(),t=tarray,dg_discrete=dg!, - dt=dt1,adaptive=false,sensealg=InterpolatingAdjoint(autojacvec=false)) + # Free memory to help Travis - @test isapprox(res_sde_u02, res_sde_u02a, rtol = 1e-6) - @test isapprox(res_sde_p2, res_sde_p2a, rtol = 1e-6) + Wfix = nothing + res_sde_forward2 = nothing + res_sde_reverse2 = nothing + resp = nothing + res_sde_trackerp2 = nothing + res_sde_u02 = nothing + sol_oop_sde2 = nothing + res_sde_u02a = nothing + res_sde_p2a = nothing + res_sde_p2 = nothing + sol_oop_sde = nothing + GC.gc() - res_sde_u02, res_sde_p2 = adjoint_sensitivities(sol_oop_sde2,EulerHeun(),t=tarray,dg_discrete=dg!, - dt=dt1,adaptive=false,sensealg=InterpolatingAdjoint(autojacvec=ZygoteVJP())) + # SDE adjoint results with diagonal noise - @test isapprox(res_sde_u02, res_sde_u02a, rtol = 1e-6) - @test isapprox(res_sde_p2, res_sde_p2a, rtol = 1e-6) - - - res_sde_u02, res_sde_p2 = adjoint_sensitivities(sol_oop_sde2,EulerHeun(),t=tarray,dg_discrete=dg!, - dt=dt1,adaptive=false,sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP())) - - @test isapprox(res_sde_u02, res_sde_u02a, rtol = 1e-6) - @test isapprox(res_sde_p2, res_sde_p2a, rtol = 1e-6) - - - # Free memory to help Travis - - Wfix = nothing - res_sde_forward2 = nothing - res_sde_reverse2 = nothing - resp = nothing - res_sde_trackerp2 = nothing - res_sde_u02 = nothing - sol_oop_sde2 = nothing - res_sde_u02a = nothing - res_sde_p2a = nothing - res_sde_p2 = nothing - sol_oop_sde = nothing - GC.gc() - - # SDE adjoint results with diagonal noise - - Random.seed!(seed) - prob_oop_sde2 = SDEProblem(f_oop_linear,σ_oop_linear,[u₀;u₀;u₀],trange,p2) - sol_oop_sde2 = solve(prob_oop_sde2,EulerHeun(),dt=dt1,adaptive=false,save_noise=true) - - @info "Diagonal Adjoint" - - res_sde_u02, res_sde_p2 = adjoint_sensitivities(sol_oop_sde2,EulerHeun(),t=tarray,dg_discrete=dg!, - dt=dt1,adaptive=false,sensealg=BacksolveAdjoint()) - - res_sde_u02a, res_sde_p2a = adjoint_sensitivities(sol_oop_sde2,EulerHeun(),t=tarray,dg_discrete=dg!, - dt=dt1,adaptive=false,sensealg=InterpolatingAdjoint()) + Random.seed!(seed) + prob_oop_sde2 = SDEProblem(f_oop_linear, σ_oop_linear, [u₀; u₀; u₀], trange, p2) + sol_oop_sde2 = solve(prob_oop_sde2, EulerHeun(), dt = dt1, adaptive = false, + save_noise = true) - @test isapprox(res_sde_p2, res_sde_p2a, rtol = 5e-4) - @test isapprox(res_sde_u02, res_sde_u02a, rtol = 2e-5) + @info "Diagonal Adjoint" - res_sde_u02a, res_sde_p2a = adjoint_sensitivities(sol_oop_sde2,EulerHeun(),t=tarray,dg_discrete=dg!, - dt=dt1,adaptive=false,sensealg=InterpolatingAdjoint(autojacvec=false)) + res_sde_u02, res_sde_p2 = adjoint_sensitivities(sol_oop_sde2, EulerHeun(), t = tarray, + dg_discrete = dg!, + dt = dt1, adaptive = false, + sensealg = BacksolveAdjoint()) - @test isapprox(res_sde_p2, res_sde_p2a, rtol = 5e-4) - @test isapprox(res_sde_u02, res_sde_u02a, rtol = 2e-5) + res_sde_u02a, res_sde_p2a = adjoint_sensitivities(sol_oop_sde2, EulerHeun(), t = tarray, + dg_discrete = dg!, + dt = dt1, adaptive = false, + sensealg = InterpolatingAdjoint()) - res_sde_u02a, res_sde_p2a = adjoint_sensitivities(sol_oop_sde2,EulerHeun(),t=tarray,dg_discrete=dg!, - dt=dt1,adaptive=false,sensealg=InterpolatingAdjoint(autojacvec=ZygoteVJP())) + @test isapprox(res_sde_p2, res_sde_p2a, rtol = 5e-4) + @test isapprox(res_sde_u02, res_sde_u02a, rtol = 2e-5) - @test isapprox(res_sde_p2, res_sde_p2a, rtol = 5e-4) - @test isapprox(res_sde_u02, res_sde_u02a, rtol = 2e-5) + res_sde_u02a, res_sde_p2a = adjoint_sensitivities(sol_oop_sde2, EulerHeun(), t = tarray, + dg_discrete = dg!, + dt = dt1, adaptive = false, + sensealg = InterpolatingAdjoint(autojacvec = false)) - res_sde_u02a, res_sde_p2a = adjoint_sensitivities(sol_oop_sde2,EulerHeun(),t=tarray,dg_discrete=dg!, - dt=dt1,adaptive=false,sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP())) + @test isapprox(res_sde_p2, res_sde_p2a, rtol = 5e-4) + @test isapprox(res_sde_u02, res_sde_u02a, rtol = 2e-5) - @test isapprox(res_sde_p2, res_sde_p2a, rtol = 5e-4) - @test isapprox(res_sde_u02, res_sde_u02a, rtol = 2e-5) + res_sde_u02a, res_sde_p2a = adjoint_sensitivities(sol_oop_sde2, EulerHeun(), t = tarray, + dg_discrete = dg!, + dt = dt1, adaptive = false, + sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP())) + @test isapprox(res_sde_p2, res_sde_p2a, rtol = 5e-4) + @test isapprox(res_sde_u02, res_sde_u02a, rtol = 2e-5) - res_sde_u02a, res_sde_p2a = adjoint_sensitivities(sol_oop_sde2,EulerHeun(),t=tarray,dg_discrete=dg!, - dt=dt1,adaptive=false,sensealg=BacksolveAdjoint(autojacvec=false)) + res_sde_u02a, res_sde_p2a = adjoint_sensitivities(sol_oop_sde2, EulerHeun(), t = tarray, + dg_discrete = dg!, + dt = dt1, adaptive = false, + sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP())) - @test isapprox(res_sde_p2, res_sde_p2a, rtol = 1e-7) - @test isapprox(res_sde_u02, res_sde_u02a, rtol = 1e-7) + @test isapprox(res_sde_p2, res_sde_p2a, rtol = 5e-4) + @test isapprox(res_sde_u02, res_sde_u02a, rtol = 2e-5) - res_sde_u02a, res_sde_p2a = adjoint_sensitivities(sol_oop_sde2,EulerHeun(),t=tarray,dg_discrete=dg!, - dt=dt1,adaptive=false,sensealg=BacksolveAdjoint(autojacvec=ZygoteVJP())) + res_sde_u02a, res_sde_p2a = adjoint_sensitivities(sol_oop_sde2, EulerHeun(), t = tarray, + dg_discrete = dg!, + dt = dt1, adaptive = false, + sensealg = BacksolveAdjoint(autojacvec = false)) - @test isapprox(res_sde_p2, res_sde_p2a, rtol = 1e-7) - @test isapprox(res_sde_u02, res_sde_u02a, rtol = 1e-7) + @test isapprox(res_sde_p2, res_sde_p2a, rtol = 1e-7) + @test isapprox(res_sde_u02, res_sde_u02a, rtol = 1e-7) - res_sde_u02a, res_sde_p2a = adjoint_sensitivities(sol_oop_sde2,EulerHeun(),t=tarray,dg_discrete=dg!, - dt=dt1,adaptive=false,sensealg=BacksolveAdjoint(autojacvec=ReverseDiffVJP())) + res_sde_u02a, res_sde_p2a = adjoint_sensitivities(sol_oop_sde2, EulerHeun(), t = tarray, + dg_discrete = dg!, + dt = dt1, adaptive = false, + sensealg = BacksolveAdjoint(autojacvec = ZygoteVJP())) - @test isapprox(res_sde_p2, res_sde_p2a, rtol = 1e-7) - @test isapprox(res_sde_u02, res_sde_u02a, rtol = 1e-7) + @test isapprox(res_sde_p2, res_sde_p2a, rtol = 1e-7) + @test isapprox(res_sde_u02, res_sde_u02a, rtol = 1e-7) - @info res_sde_p2 + res_sde_u02a, res_sde_p2a = adjoint_sensitivities(sol_oop_sde2, EulerHeun(), t = tarray, + dg_discrete = dg!, + dt = dt1, adaptive = false, + sensealg = BacksolveAdjoint(autojacvec = ReverseDiffVJP())) - sol_oop_sde2 = nothing - GC.gc() + @test isapprox(res_sde_p2, res_sde_p2a, rtol = 1e-7) + @test isapprox(res_sde_u02, res_sde_u02a, rtol = 1e-7) + @info res_sde_p2 - @info "Diagonal ForwardDiff" - res_sde_forward2 = ForwardDiff.gradient(GSDE2,p2) + sol_oop_sde2 = nothing + GC.gc() - #@test isapprox(res_sde_forward2, res_sde_reverse2, rtol = 1e-6) - @test isapprox(res_sde_p2', res_sde_forward2, rtol = 1e-3) - #@test isapprox(res_sde_p2', res_sde_reverse2, rtol = 1e-3) + @info "Diagonal ForwardDiff" + res_sde_forward2 = ForwardDiff.gradient(GSDE2, p2) - # u0 - function GSDE3(u) - Random.seed!(seed) - tmp_prob = remake(prob_oop_sde2,u0=u) - sol = solve(tmp_prob,RKMil(interpretation=:Stratonovich),dt=dt1,adaptive=false,saveat=tarray) - A = convert(Array,sol) - res = g(A,nothing,nothing) - end + #@test isapprox(res_sde_forward2, res_sde_reverse2, rtol = 1e-6) + @test isapprox(res_sde_p2', res_sde_forward2, rtol = 1e-3) + #@test isapprox(res_sde_p2', res_sde_reverse2, rtol = 1e-3) - @info "ForwardDiff u0" - res_sde_forward2 = ForwardDiff.gradient(GSDE3,[u₀;u₀;u₀]) + # u0 + function GSDE3(u) + Random.seed!(seed) + tmp_prob = remake(prob_oop_sde2, u0 = u) + sol = solve(tmp_prob, RKMil(interpretation = :Stratonovich), dt = dt1, + adaptive = false, saveat = tarray) + A = convert(Array, sol) + res = g(A, nothing, nothing) + end - @test isapprox(res_sde_u02, res_sde_forward2, rtol = 1e-4) + @info "ForwardDiff u0" + res_sde_forward2 = ForwardDiff.gradient(GSDE3, [u₀; u₀; u₀]) + @test isapprox(res_sde_u02, res_sde_forward2, rtol = 1e-4) end - - ## ## Inplace ## @testset "SDE inplace Tests" begin + f!(du, u, p, t) = du .= p[1] * u + σ!(du, u, p, t) = du .= p[2] * u - f!(du,u,p,t) = du.=p[1]*u - σ!(du,u,p,t) = du.=p[2]*u + dt1 = tend / 1e3 - dt1 = tend/1e3 - - Random.seed!(seed) - prob_sde = SDEProblem(f!,σ!,u₀,trange,p2) - sol_sde = solve(prob_sde,EulerHeun(),dt=dt1,adaptive=false, save_noise=true) - - function GSDE(p) Random.seed!(seed) - tmp_prob = remake(prob_sde,u0=eltype(p).(prob_sde.u0),p=p, - tspan=eltype(p).(prob_sde.tspan)) - sol = solve(tmp_prob,EulerHeun(),dt=dt1,adaptive=false,saveat=tarray) - A = convert(Array,sol) - res = g(A,p,nothing) - end - - res_sde_forward = ForwardDiff.gradient(GSDE,p2) + prob_sde = SDEProblem(f!, σ!, u₀, trange, p2) + sol_sde = solve(prob_sde, EulerHeun(), dt = dt1, adaptive = false, save_noise = true) + function GSDE(p) + Random.seed!(seed) + tmp_prob = remake(prob_sde, u0 = eltype(p).(prob_sde.u0), p = p, + tspan = eltype(p).(prob_sde.tspan)) + sol = solve(tmp_prob, EulerHeun(), dt = dt1, adaptive = false, saveat = tarray) + A = convert(Array, sol) + res = g(A, p, nothing) + end - res_sde_u0, res_sde_p = adjoint_sensitivities(sol_sde,EulerHeun(),t=tarray,dg_discrete=dg!, - dt=dt1,adaptive=false,sensealg=BacksolveAdjoint()) + res_sde_forward = ForwardDiff.gradient(GSDE, p2) - @test isapprox(res_sde_p', res_sde_forward, rtol = 1e-4) + res_sde_u0, res_sde_p = adjoint_sensitivities(sol_sde, EulerHeun(), t = tarray, + dg_discrete = dg!, + dt = dt1, adaptive = false, + sensealg = BacksolveAdjoint()) - @info res_sde_p + @test isapprox(res_sde_p', res_sde_forward, rtol = 1e-4) - res_sde_u02, res_sde_p2 = adjoint_sensitivities(sol_sde,EulerHeun(),t=tarray,dg_discrete=dg!, - dt=dt1,adaptive=false,sensealg=BacksolveAdjoint(autojacvec=false)) + @info res_sde_p - @info res_sde_p2 + res_sde_u02, res_sde_p2 = adjoint_sensitivities(sol_sde, EulerHeun(), t = tarray, + dg_discrete = dg!, + dt = dt1, adaptive = false, + sensealg = BacksolveAdjoint(autojacvec = false)) - @test isapprox(res_sde_p, res_sde_p2, rtol = 1e-5) - @test isapprox(res_sde_u0, res_sde_u02, rtol = 1e-5) + @info res_sde_p2 - res_sde_u02, res_sde_p2 = adjoint_sensitivities(sol_sde,EulerHeun(),t=tarray,dg_discrete=dg!, - dt=dt1,adaptive=false,sensealg=BacksolveAdjoint(autojacvec=ZygoteVJP())) + @test isapprox(res_sde_p, res_sde_p2, rtol = 1e-5) + @test isapprox(res_sde_u0, res_sde_u02, rtol = 1e-5) - @info res_sde_p2 + res_sde_u02, res_sde_p2 = adjoint_sensitivities(sol_sde, EulerHeun(), t = tarray, + dg_discrete = dg!, + dt = dt1, adaptive = false, + sensealg = BacksolveAdjoint(autojacvec = ZygoteVJP())) - @test isapprox(res_sde_p, res_sde_p2, rtol = 1e-5) # not broken here because it just uses the vjps - @test isapprox(res_sde_u0 ,res_sde_u02, rtol = 1e-5) + @info res_sde_p2 - res_sde_u02, res_sde_p2 = adjoint_sensitivities(sol_sde,EulerHeun(),t=tarray,dg_discrete=dg!, - dt=dt1,adaptive=false,sensealg=BacksolveAdjoint(autojacvec=ReverseDiffVJP())) + @test isapprox(res_sde_p, res_sde_p2, rtol = 1e-5) # not broken here because it just uses the vjps + @test isapprox(res_sde_u0, res_sde_u02, rtol = 1e-5) - @info res_sde_p2 + res_sde_u02, res_sde_p2 = adjoint_sensitivities(sol_sde, EulerHeun(), t = tarray, + dg_discrete = dg!, + dt = dt1, adaptive = false, + sensealg = BacksolveAdjoint(autojacvec = ReverseDiffVJP())) - @test isapprox(res_sde_p, res_sde_p2, rtol = 1e-10) - @test isapprox(res_sde_u0 ,res_sde_u02, rtol = 1e-10) + @info res_sde_p2 + @test isapprox(res_sde_p, res_sde_p2, rtol = 1e-10) + @test isapprox(res_sde_u0, res_sde_u02, rtol = 1e-10) - res_sde_u02, res_sde_p2 = adjoint_sensitivities(sol_sde,EulerHeun(),t=tarray,dg_discrete=dg!, - dt=dt1,adaptive=false,sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP())) + res_sde_u02, res_sde_p2 = adjoint_sensitivities(sol_sde, EulerHeun(), t = tarray, + dg_discrete = dg!, + dt = dt1, adaptive = false, + sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP())) - @test isapprox(res_sde_p, res_sde_p2, rtol = 2e-4) - @test isapprox(res_sde_u0 ,res_sde_u02, rtol = 1e-4) + @test isapprox(res_sde_p, res_sde_p2, rtol = 2e-4) + @test isapprox(res_sde_u0, res_sde_u02, rtol = 1e-4) - res_sde_u0, res_sde_p = adjoint_sensitivities(sol_sde,EulerHeun(),t=tarray,dg_discrete=dg!, - dt=dt1,adaptive=false,sensealg=InterpolatingAdjoint()) + res_sde_u0, res_sde_p = adjoint_sensitivities(sol_sde, EulerHeun(), t = tarray, + dg_discrete = dg!, + dt = dt1, adaptive = false, + sensealg = InterpolatingAdjoint()) - @test isapprox(res_sde_p, res_sde_p2, rtol = 1e-7) - @test isapprox(res_sde_u0 ,res_sde_u02, rtol = 1e-7) + @test isapprox(res_sde_p, res_sde_p2, rtol = 1e-7) + @test isapprox(res_sde_u0, res_sde_u02, rtol = 1e-7) - res_sde_u0, res_sde_p = adjoint_sensitivities(sol_sde,EulerHeun(),t=tarray,dg_discrete=dg!, - dt=dt1,adaptive=false,sensealg=InterpolatingAdjoint(autojacvec=false)) + res_sde_u0, res_sde_p = adjoint_sensitivities(sol_sde, EulerHeun(), t = tarray, + dg_discrete = dg!, + dt = dt1, adaptive = false, + sensealg = InterpolatingAdjoint(autojacvec = false)) - @test isapprox(res_sde_p, res_sde_p2, rtol = 1e-7) - @test isapprox(res_sde_u0 ,res_sde_u02, rtol = 1e-7) + @test isapprox(res_sde_p, res_sde_p2, rtol = 1e-7) + @test isapprox(res_sde_u0, res_sde_u02, rtol = 1e-7) - res_sde_u02, res_sde_p2 = adjoint_sensitivities(sol_sde,EulerHeun(),t=tarray,dg_discrete=dg!, - dt=dt1,adaptive=false,sensealg=InterpolatingAdjoint(autojacvec=ZygoteVJP())) + res_sde_u02, res_sde_p2 = adjoint_sensitivities(sol_sde, EulerHeun(), t = tarray, + dg_discrete = dg!, + dt = dt1, adaptive = false, + sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP())) - @test isapprox(res_sde_p, res_sde_p2, rtol = 1e-7) - @test isapprox(res_sde_u0 ,res_sde_u02, rtol = 1e-7) + @test isapprox(res_sde_p, res_sde_p2, rtol = 1e-7) + @test isapprox(res_sde_u0, res_sde_u02, rtol = 1e-7) - # diagonal noise + # diagonal noise - #compare with oop version - f_oop_linear(u,p,t) = p[1]*u - σ_oop_linear(u,p,t) = p[2]*u - Random.seed!(seed) - prob_oop_sde = SDEProblem(f_oop_linear,σ_oop_linear,[u₀;u₀;u₀],trange,p2) - sol_oop_sde = solve(prob_oop_sde,EulerHeun(),dt=dt1,adaptive=false,save_noise=true) - res_oop_u0, res_oop_p = adjoint_sensitivities(sol_oop_sde,EulerHeun(),t=tarray,dg_discrete=dg!, - dt=dt1,adaptive=false,sensealg=BacksolveAdjoint()) - - @info res_oop_p + #compare with oop version + f_oop_linear(u, p, t) = p[1] * u + σ_oop_linear(u, p, t) = p[2] * u + Random.seed!(seed) + prob_oop_sde = SDEProblem(f_oop_linear, σ_oop_linear, [u₀; u₀; u₀], trange, p2) + sol_oop_sde = solve(prob_oop_sde, EulerHeun(), dt = dt1, adaptive = false, + save_noise = true) + res_oop_u0, res_oop_p = adjoint_sensitivities(sol_oop_sde, EulerHeun(), t = tarray, + dg_discrete = dg!, + dt = dt1, adaptive = false, + sensealg = BacksolveAdjoint()) - Random.seed!(seed) - prob_sde = SDEProblem(f!,σ!,[u₀;u₀;u₀],trange,p2) - sol_sde = solve(prob_sde,EulerHeun(),dt=dt1,adaptive=false,save_noise=true) + @info res_oop_p - res_sde_u0, res_sde_p = adjoint_sensitivities(sol_sde,EulerHeun(),t=tarray,dg_discrete=dg!, - dt=dt1,adaptive=false,sensealg=BacksolveAdjoint()) + Random.seed!(seed) + prob_sde = SDEProblem(f!, σ!, [u₀; u₀; u₀], trange, p2) + sol_sde = solve(prob_sde, EulerHeun(), dt = dt1, adaptive = false, save_noise = true) - isapprox(res_sde_p, res_oop_p, rtol = 1e-6) - isapprox(res_sde_u0 ,res_oop_u0, rtol = 1e-6) + res_sde_u0, res_sde_p = adjoint_sensitivities(sol_sde, EulerHeun(), t = tarray, + dg_discrete = dg!, + dt = dt1, adaptive = false, + sensealg = BacksolveAdjoint()) - @info res_sde_p + isapprox(res_sde_p, res_oop_p, rtol = 1e-6) + isapprox(res_sde_u0, res_oop_u0, rtol = 1e-6) - res_sde_u0, res_sde_p = adjoint_sensitivities(sol_sde,EulerHeun(),t=tarray,dg_discrete=dg!, - dt=dt1,adaptive=false,sensealg=BacksolveAdjoint(autojacvec=false)) + @info res_sde_p - @test isapprox(res_sde_p, res_oop_p, rtol = 1e-6) - @test isapprox(res_sde_u0 ,res_oop_u0, rtol = 1e-6) + res_sde_u0, res_sde_p = adjoint_sensitivities(sol_sde, EulerHeun(), t = tarray, + dg_discrete = dg!, + dt = dt1, adaptive = false, + sensealg = BacksolveAdjoint(autojacvec = false)) - @info res_sde_p + @test isapprox(res_sde_p, res_oop_p, rtol = 1e-6) + @test isapprox(res_sde_u0, res_oop_u0, rtol = 1e-6) - res_sde_u0, res_sde_p = adjoint_sensitivities(sol_sde,EulerHeun(),t=tarray,dg_discrete=dg!, - dt=dt1,adaptive=false,sensealg=BacksolveAdjoint(autojacvec=ZygoteVJP())) + @info res_sde_p - @test isapprox(res_sde_p, res_oop_p, rtol = 1e-6) - @test isapprox(res_sde_u0 ,res_oop_u0, rtol = 1e-6) + res_sde_u0, res_sde_p = adjoint_sensitivities(sol_sde, EulerHeun(), t = tarray, + dg_discrete = dg!, + dt = dt1, adaptive = false, + sensealg = BacksolveAdjoint(autojacvec = ZygoteVJP())) - @info res_sde_p + @test isapprox(res_sde_p, res_oop_p, rtol = 1e-6) + @test isapprox(res_sde_u0, res_oop_u0, rtol = 1e-6) - res_sde_u0, res_sde_p = adjoint_sensitivities(sol_sde,EulerHeun(),t=tarray,dg_discrete=dg!, - dt=dt1,adaptive=false,sensealg=BacksolveAdjoint(autojacvec=ReverseDiffVJP())) + @info res_sde_p - @test isapprox(res_sde_p, res_oop_p, rtol = 1e-6) - @test isapprox(res_sde_u0 ,res_oop_u0, rtol = 1e-6) + res_sde_u0, res_sde_p = adjoint_sensitivities(sol_sde, EulerHeun(), t = tarray, + dg_discrete = dg!, + dt = dt1, adaptive = false, + sensealg = BacksolveAdjoint(autojacvec = ReverseDiffVJP())) - @info res_sde_p + @test isapprox(res_sde_p, res_oop_p, rtol = 1e-6) + @test isapprox(res_sde_u0, res_oop_u0, rtol = 1e-6) + @info res_sde_p - res_sde_u0, res_sde_p = adjoint_sensitivities(sol_sde,EulerHeun(),t=tarray,dg_discrete=dg!, - dt=dt1,adaptive=false,sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP())) + res_sde_u0, res_sde_p = adjoint_sensitivities(sol_sde, EulerHeun(), t = tarray, + dg_discrete = dg!, + dt = dt1, adaptive = false, + sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP())) - @test isapprox(res_sde_p, res_oop_p, rtol = 5e-4) - @test isapprox(res_sde_u0 ,res_oop_u0, rtol = 1e-4) + @test isapprox(res_sde_p, res_oop_p, rtol = 5e-4) + @test isapprox(res_sde_u0, res_oop_u0, rtol = 1e-4) - res_sde_u0, res_sde_p = adjoint_sensitivities(sol_sde,EulerHeun(),t=tarray,dg_discrete=dg!, - dt=dt1,adaptive=false,sensealg=InterpolatingAdjoint()) + res_sde_u0, res_sde_p = adjoint_sensitivities(sol_sde, EulerHeun(), t = tarray, + dg_discrete = dg!, + dt = dt1, adaptive = false, + sensealg = InterpolatingAdjoint()) - @test isapprox(res_sde_p, res_oop_p, rtol = 5e-4) - @test isapprox(res_sde_u0 ,res_oop_u0, rtol = 1e-4) + @test isapprox(res_sde_p, res_oop_p, rtol = 5e-4) + @test isapprox(res_sde_u0, res_oop_u0, rtol = 1e-4) - res_sde_u0, res_sde_p = adjoint_sensitivities(sol_sde,EulerHeun(),t=tarray,dg_discrete=dg!, - dt=dt1,adaptive=false,sensealg=InterpolatingAdjoint(autojacvec=false)) + res_sde_u0, res_sde_p = adjoint_sensitivities(sol_sde, EulerHeun(), t = tarray, + dg_discrete = dg!, + dt = dt1, adaptive = false, + sensealg = InterpolatingAdjoint(autojacvec = false)) - @test isapprox(res_sde_p, res_oop_p, rtol = 5e-4) - @test isapprox(res_sde_u0 ,res_oop_u0, rtol = 1e-4) + @test isapprox(res_sde_p, res_oop_p, rtol = 5e-4) + @test isapprox(res_sde_u0, res_oop_u0, rtol = 1e-4) - res_sde_u0, res_sde_p = adjoint_sensitivities(sol_sde,EulerHeun(),t=tarray,dg_discrete=dg!, - dt=dt1,adaptive=false,sensealg=InterpolatingAdjoint(autojacvec=ZygoteVJP())) + res_sde_u0, res_sde_p = adjoint_sensitivities(sol_sde, EulerHeun(), t = tarray, + dg_discrete = dg!, + dt = dt1, adaptive = false, + sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP())) - @test_broken isapprox(res_sde_p, res_oop_p, rtol = 1e-4) - @test isapprox(res_sde_u0 ,res_oop_u0, rtol = 1e-4) + @test_broken isapprox(res_sde_p, res_oop_p, rtol = 1e-4) + @test isapprox(res_sde_u0, res_oop_u0, rtol = 1e-4) end - @testset "SDE oop Tests (Tracker)" begin + f_oop_linear(u, p, t) = p[1] * u + σ_oop_linear(u, p, t) = p[2] * u - f_oop_linear(u,p,t) = p[1]*u - σ_oop_linear(u,p,t) = p[2]*u - - function f_oop_linear(u::Tracker.TrackedArray,p,t) - p[1].*u - end + function f_oop_linear(u::Tracker.TrackedArray, p, t) + p[1] .* u + end - function σ_oop_linear(u::Tracker.TrackedArray,p,t) - p[2].*u - end + function σ_oop_linear(u::Tracker.TrackedArray, p, t) + p[2] .* u + end - Random.seed!(seed) - prob_oop_sde = SDEProblem(f_oop_linear,σ_oop_linear,u₀,trange,p2) + Random.seed!(seed) + prob_oop_sde = SDEProblem(f_oop_linear, σ_oop_linear, u₀, trange, p2) + + function GSDE1(p) + Random.seed!(seed) + tmp_prob = remake(prob_oop_sde, u0 = eltype(p).(prob_oop_sde.u0), p = p, + tspan = eltype(p).(prob_oop_sde.tspan)) + sol = solve(tmp_prob, RKMil(interpretation = :Stratonovich), dt = 5e-4, + adaptive = false, sensealg = DiffEqBase.SensitivityADPassThrough(), + saveat = tarray) + A = convert(Array, sol) + res = g(A, p, nothing) + end + res_sde_forward = ForwardDiff.gradient(GSDE1, p2) - function GSDE1(p) Random.seed!(seed) - tmp_prob = remake(prob_oop_sde,u0=eltype(p).(prob_oop_sde.u0),p=p, - tspan=eltype(p).(prob_oop_sde.tspan)) - sol = solve(tmp_prob,RKMil(interpretation=:Stratonovich),dt=5e-4,adaptive=false,sensealg=DiffEqBase.SensitivityADPassThrough(),saveat=tarray) - A = convert(Array,sol) - res = g(A,p,nothing) - end - res_sde_forward = ForwardDiff.gradient(GSDE1,p2) - - Random.seed!(seed) - res_sde_trackeru0, res_sde_trackerp = Zygote.gradient((u0,p)->sum(Array(solve(prob_oop_sde, - RKMil(interpretation=:Stratonovich),dt=5e-4,adaptive=false,u0=u0,p=p,saveat=tarray, - sensealg=TrackerAdjoint())).^2.0/2.0),u₀,p2) - - @test isapprox(res_sde_forward, res_sde_trackerp, rtol = 1e-5) + res_sde_trackeru0, res_sde_trackerp = Zygote.gradient((u0, p) -> sum(Array(solve(prob_oop_sde, + RKMil(interpretation = :Stratonovich), + dt = 5e-4, + adaptive = false, + u0 = u0, + p = p, + saveat = tarray, + sensealg = TrackerAdjoint())) .^ + 2.0 / 2.0), u₀, p2) + + @test isapprox(res_sde_forward, res_sde_trackerp, rtol = 1e-5) end diff --git a/test/sde_transformation_test.jl b/test/sde_transformation_test.jl index 9650fcf9b..7bcc4e9a6 100644 --- a/test/sde_transformation_test.jl +++ b/test/sde_transformation_test.jl @@ -6,160 +6,166 @@ using Random seed = 100 tspan = (0.0, 0.1) -p = [1.01,0.87] +p = [1.01, 0.87] # scalar -f(u,p,t) = p[1]*u -σ(u,p,t) = p[2]*u +f(u, p, t) = p[1] * u +σ(u, p, t) = p[2] * u Random.seed!(seed) u0 = rand(1) -linear_analytic(u0,p,t,W) = @.(u0*exp((p[1]-p[2]^2/2)*t+p[2]*W)) +linear_analytic(u0, p, t, W) = @.(u0*exp((p[1] - p[2]^2 / 2) * t + p[2] * W)) -prob = SDEProblem(SDEFunction(f,σ,analytic=linear_analytic),σ,u0,tspan,p) -sol = solve(prob,SOSRI(),adaptive=false, dt=0.001, save_noise=true) +prob = SDEProblem(SDEFunction(f, σ, analytic = linear_analytic), σ, u0, tspan, p) +sol = solve(prob, SOSRI(), adaptive = false, dt = 0.001, save_noise = true) -@test isapprox(sol.u_analytic,sol.u, atol=1e-4) +@test isapprox(sol.u_analytic, sol.u, atol = 1e-4) du = zeros(size(u0)) u = sol.u[end] -transformed_function = StochasticTransformedFunction(sol,sol.prob.f,sol.prob.g) +transformed_function = StochasticTransformedFunction(sol, sol.prob.f, sol.prob.g) #transformed_function(du,u,p,tspan[2]) -du2 = transformed_function(u,p,tspan[2]) +du2 = transformed_function(u, p, tspan[2]) #@test du[1] == (p[1]*u[1]-p[2]^2*u[1]) -@test isapprox(du2[1], (p[1]*u[1]-p[2]^2*u[1]), atol=1e-15) +@test isapprox(du2[1], (p[1] * u[1] - p[2]^2 * u[1]), atol = 1e-15) #@test du2 == du -transformed_function = StochasticTransformedFunction(sol,sol.prob.f,sol.prob.g,(u,p,t)->p[2]^2*u) -du2 = transformed_function(u,p,tspan[2]) -@test isapprox(du2[1], (p[1]*u[1]-p[2]^2*u[1]), atol=1e-15) +transformed_function = StochasticTransformedFunction(sol, sol.prob.f, sol.prob.g, + (u, p, t) -> p[2]^2 * u) +du2 = transformed_function(u, p, tspan[2]) +@test isapprox(du2[1], (p[1] * u[1] - p[2]^2 * u[1]), atol = 1e-15) -linear_analytic_strat(u0,p,t,W) = @.(u0*exp((p[1])*t+p[2]*W)) +linear_analytic_strat(u0, p, t, W) = @.(u0*exp((p[1]) * t + p[2] * W)) -prob_strat = SDEProblem{false}(SDEFunction((u,p,t)->p[1]*u-1//2*p[2]^2*u,σ,analytic=linear_analytic_strat),σ,u0,tspan,p) +prob_strat = SDEProblem{false}(SDEFunction((u, p, t) -> p[1] * u - 1 // 2 * p[2]^2 * u, σ, + analytic = linear_analytic_strat), σ, u0, tspan, + p) Random.seed!(seed) -sol_strat = solve(prob_strat,RKMil(interpretation=:Stratonovich),adaptive=false, dt=0.0001, save_noise=true) -prob_strat1 = SDEProblem{false}(SDEFunction((u,p,t)->transformed_function(u,p,t).+1//2*p[2]^2*u[1],σ,analytic=linear_analytic),σ,u0,tspan,p) +sol_strat = solve(prob_strat, RKMil(interpretation = :Stratonovich), adaptive = false, + dt = 0.0001, save_noise = true) +prob_strat1 = SDEProblem{false}(SDEFunction((u, p, t) -> transformed_function(u, p, t) .+ + 1 // 2 * p[2]^2 * u[1], σ, + analytic = linear_analytic), σ, u0, tspan, p) Random.seed!(seed) -sol_strat1 = solve(prob_strat1,RKMil(interpretation=:Stratonovich),adaptive=false, dt=0.0001, save_noise=true) +sol_strat1 = solve(prob_strat1, RKMil(interpretation = :Stratonovich), adaptive = false, + dt = 0.0001, save_noise = true) # Test if we recover Ito solution in Stratonovich sense -@test isapprox(sol_strat.u, sol_strat1.u, atol=1e-4) # own transformation and custom function agree -@test !isapprox(sol_strat.u_analytic,sol_strat.u, atol=1e-4) # we don't get the stratonovich solution for the linear SDE -@test isapprox(sol_strat1.u_analytic,sol_strat.u, atol=1e-3) # we do recover the analytic solution from the Ito sense +@test isapprox(sol_strat.u, sol_strat1.u, atol = 1e-4) # own transformation and custom function agree +@test !isapprox(sol_strat.u_analytic, sol_strat.u, atol = 1e-4) # we don't get the stratonovich solution for the linear SDE +@test isapprox(sol_strat1.u_analytic, sol_strat.u, atol = 1e-3) # we do recover the analytic solution from the Ito sense # inplace -f!(du,u,p,t) = @.(du = p[1]*u) -σ!(du,u,p,t) = @.(du = p[2]*u) +f!(du, u, p, t) = @.(du=p[1] * u) +σ!(du, u, p, t) = @.(du=p[2] * u) -prob = SDEProblem(SDEFunction(f!,σ!,analytic=linear_analytic),σ!,u0,tspan,p) -sol = solve(prob,SOSRI(),adaptive=false, dt=0.001, save_noise=true) +prob = SDEProblem(SDEFunction(f!, σ!, analytic = linear_analytic), σ!, u0, tspan, p) +sol = solve(prob, SOSRI(), adaptive = false, dt = 0.001, save_noise = true) -@test isapprox(sol.u_analytic,sol.u, atol=1e-4) +@test isapprox(sol.u_analytic, sol.u, atol = 1e-4) du = zeros(size(u0)) u = sol.u[end] -transformed_function = StochasticTransformedFunction(sol,sol.prob.f,sol.prob.g) -transformed_function(du,u,p,tspan[2]) +transformed_function = StochasticTransformedFunction(sol, sol.prob.f, sol.prob.g) +transformed_function(du, u, p, tspan[2]) -@test isapprox(du[1], (p[1]*u[1]-p[2]^2*u[1]), atol=1e-15) +@test isapprox(du[1], (p[1] * u[1] - p[2]^2 * u[1]), atol = 1e-15) # @test isapprox(du2[1], (p[1]*u[1]-p[2]^2*u[1]), atol=1e-15) # @test isapprox(du2, du, atol=1e-15) -transformed_function = StochasticTransformedFunction(sol,sol.prob.f,sol.prob.g,(du,u,p,t)-> (du.=p[2]^2*u)) -transformed_function(du,u,p,tspan[2]) -@test du[1] == (p[1]*u[1]-p[2]^2*u[1]) - +transformed_function = StochasticTransformedFunction(sol, sol.prob.f, sol.prob.g, + (du, u, p, t) -> (du .= p[2]^2 * u)) +transformed_function(du, u, p, tspan[2]) +@test du[1] == (p[1] * u[1] - p[2]^2 * u[1]) # diagonal noise u0 = rand(3) -prob = SDEProblem(SDEFunction(f,σ,analytic=linear_analytic),σ,u0,tspan,p) -sol = solve(prob,SOSRI(),adaptive=false, dt=0.001, save_noise=true) +prob = SDEProblem(SDEFunction(f, σ, analytic = linear_analytic), σ, u0, tspan, p) +sol = solve(prob, SOSRI(), adaptive = false, dt = 0.001, save_noise = true) u = sol.u[end] -transformed_function = StochasticTransformedFunction(sol,sol.prob.f,sol.prob.g) -du2 = transformed_function(u,p,tspan[2]) -@test isapprox(du2,(p[1]*u-p[2]^2*u), atol=1e-15) - -transformed_function = StochasticTransformedFunction(sol,sol.prob.f,sol.prob.g,(u,p,t)->p[2]^2*u) -du2 = transformed_function(u,p,tspan[2]) -@test du2[1] == (p[1]*u[1]-p[2]^2*u[1]) +transformed_function = StochasticTransformedFunction(sol, sol.prob.f, sol.prob.g) +du2 = transformed_function(u, p, tspan[2]) +@test isapprox(du2, (p[1] * u - p[2]^2 * u), atol = 1e-15) +transformed_function = StochasticTransformedFunction(sol, sol.prob.f, sol.prob.g, + (u, p, t) -> p[2]^2 * u) +du2 = transformed_function(u, p, tspan[2]) +@test du2[1] == (p[1] * u[1] - p[2]^2 * u[1]) -prob = SDEProblem(SDEFunction(f!,σ!,analytic=linear_analytic),σ!,u0,tspan,p) -sol = solve(prob,SOSRI(),adaptive=false, dt=0.001, save_noise=true) +prob = SDEProblem(SDEFunction(f!, σ!, analytic = linear_analytic), σ!, u0, tspan, p) +sol = solve(prob, SOSRI(), adaptive = false, dt = 0.001, save_noise = true) du = zeros(size(u0)) u = sol.u[end] -transformed_function = StochasticTransformedFunction(sol,sol.prob.f,sol.prob.g) -transformed_function(du,u,p,tspan[2]) -@test isapprox(du,(p[1]*u-p[2]^2*u), atol=1e-15) +transformed_function = StochasticTransformedFunction(sol, sol.prob.f, sol.prob.g) +transformed_function(du, u, p, tspan[2]) +@test isapprox(du, (p[1] * u - p[2]^2 * u), atol = 1e-15) -transformed_function = StochasticTransformedFunction(sol,sol.prob.f,sol.prob.g,(du,u,p,t)-> (du.=p[2]^2*u)) -transformed_function(du,u,p,tspan[2]) -@test isapprox(du,(p[1]*u-p[2]^2*u), atol=1e-15) +transformed_function = StochasticTransformedFunction(sol, sol.prob.f, sol.prob.g, + (du, u, p, t) -> (du .= p[2]^2 * u)) +transformed_function(du, u, p, tspan[2]) +@test isapprox(du, (p[1] * u - p[2]^2 * u), atol = 1e-15) # non-diagonal noise torus u0 = rand(2) -p = rand(1) +p = rand(1) -fnd(u,p,t) = 0*u -function σnd(u,p,t) - du = [cos(p[1])*sin(u[1]) cos(p[1])*cos(u[1]) -sin(p[1])*sin(u[2]) -sin(p[1])*cos(u[2]) - sin(p[1])*sin(u[1]) sin(p[1])*cos(u[1]) cos(p[1])*sin(u[2]) cos(p[1])*cos(u[2]) ] - return du +fnd(u, p, t) = 0 * u +function σnd(u, p, t) + du = [cos(p[1])*sin(u[1]) cos(p[1])*cos(u[1]) -sin(p[1])*sin(u[2]) -sin(p[1])*cos(u[2]) + sin(p[1])*sin(u[1]) sin(p[1])*cos(u[1]) cos(p[1])*sin(u[2]) cos(p[1])*cos(u[2])] + return du end -prob = SDEProblem(fnd,σnd,u0,tspan,p,noise_rate_prototype=zeros(2,4)) -sol = solve(prob,EM(),adaptive=false, dt=0.001, save_noise=true) - - -transformed_function = StochasticTransformedFunction(sol,sol.prob.f,sol.prob.g) -du2 = transformed_function(u0,p,tspan[2]) -@test isapprox(du2,zeros(2), atol=1e-15) - -transformed_function = StochasticTransformedFunction(sol,sol.prob.f,sol.prob.g,(u,p,t)->false*u) -du2 = transformed_function(u0,p,tspan[2]) -@test isapprox(du2,zeros(2), atol=1e-15) - - -fnd!(du,u,p,t) = du .= false -function σnd!(du,u,p,t) - du[1,1] = cos(p[1])*sin(u[1]) - du[1,2] = cos(p[1])*cos(u[1]) - du[1,3] = -sin(p[1])*sin(u[2]) - du[1,4] = -sin(p[1])*cos(u[2]) - du[2,1] = sin(p[1])*sin(u[1]) - du[2,2] = sin(p[1])*cos(u[1]) - du[2,3] = cos(p[1])*sin(u[2]) - du[2,4] = cos(p[1])*cos(u[2]) - return nothing +prob = SDEProblem(fnd, σnd, u0, tspan, p, noise_rate_prototype = zeros(2, 4)) +sol = solve(prob, EM(), adaptive = false, dt = 0.001, save_noise = true) + +transformed_function = StochasticTransformedFunction(sol, sol.prob.f, sol.prob.g) +du2 = transformed_function(u0, p, tspan[2]) +@test isapprox(du2, zeros(2), atol = 1e-15) + +transformed_function = StochasticTransformedFunction(sol, sol.prob.f, sol.prob.g, + (u, p, t) -> false * u) +du2 = transformed_function(u0, p, tspan[2]) +@test isapprox(du2, zeros(2), atol = 1e-15) + +fnd!(du, u, p, t) = du .= false +function σnd!(du, u, p, t) + du[1, 1] = cos(p[1]) * sin(u[1]) + du[1, 2] = cos(p[1]) * cos(u[1]) + du[1, 3] = -sin(p[1]) * sin(u[2]) + du[1, 4] = -sin(p[1]) * cos(u[2]) + du[2, 1] = sin(p[1]) * sin(u[1]) + du[2, 2] = sin(p[1]) * cos(u[1]) + du[2, 3] = cos(p[1]) * sin(u[2]) + du[2, 4] = cos(p[1]) * cos(u[2]) + return nothing end -prob = SDEProblem(fnd!,σnd!,u0,tspan,p,noise_rate_prototype=zeros(2,4)) -sol = solve(prob,EM(),adaptive=false, dt=0.001, save_noise=true) +prob = SDEProblem(fnd!, σnd!, u0, tspan, p, noise_rate_prototype = zeros(2, 4)) +sol = solve(prob, EM(), adaptive = false, dt = 0.001, save_noise = true) du = zeros(size(u0)) u = sol.u[end] -transformed_function = StochasticTransformedFunction(sol,sol.prob.f,sol.prob.g) -transformed_function(du,u,p,tspan[2]) -@test isapprox(du,zeros(2), atol=1e-15) +transformed_function = StochasticTransformedFunction(sol, sol.prob.f, sol.prob.g) +transformed_function(du, u, p, tspan[2]) +@test isapprox(du, zeros(2), atol = 1e-15) -transformed_function = StochasticTransformedFunction(sol,sol.prob.f,sol.prob.g,(du,u,p,t)-> (du.=false*u)) -transformed_function(du,u,p,tspan[2]) -@test isapprox(du,zeros(2), atol=1e-15) +transformed_function = StochasticTransformedFunction(sol, sol.prob.f, sol.prob.g, + (du, u, p, t) -> (du .= false * u)) +transformed_function(du, u, p, tspan[2]) +@test isapprox(du, zeros(2), atol = 1e-15) t = sol.t[end] - """ Check compatibility of StochasticTransformedFunction with vjp for adjoints """ - ### # Check general compatibility of StochasticTransformedFunction() with Zygote ### @@ -173,34 +179,33 @@ p = rand(2) λ = rand(1) _dy, back = Zygote.pullback(u0, p) do u, p - vec(f(u, p, t)-p[2]^2*u) + vec(f(u, p, t) - p[2]^2 * u) end -∇1,∇2 = back(λ) +∇1, ∇2 = back(λ) -@test isapprox(∇1, (p[1]-p[2]^2)*λ, atol=1e-15) -@test isapprox(∇2, (@. [1,-2*p[2]]*u0*λ[1]), atol=1e-15) - -prob = SDEProblem(f,σ,u0,tspan,p) -sol = solve(prob,SOSRI(),adaptive=false, dt=0.001, save_noise=true) -transformed_function = StochasticTransformedFunction(sol,sol.prob.f,sol.prob.g) +@test isapprox(∇1, (p[1] - p[2]^2) * λ, atol = 1e-15) +@test isapprox(∇2, (@. [1, -2 * p[2]] * u0 * λ[1]), atol = 1e-15) +prob = SDEProblem(f, σ, u0, tspan, p) +sol = solve(prob, SOSRI(), adaptive = false, dt = 0.001, save_noise = true) +transformed_function = StochasticTransformedFunction(sol, sol.prob.f, sol.prob.g) # Zygote doesn't allow nesting _dy, back = Zygote.pullback(u0, p) do u, p - vec(transformed_function(u, p, t)) + vec(transformed_function(u, p, t)) end @test_broken back(λ) # @test isapprox(∇1, (p[1]-p[2]^2)*λ, atol=1e-15) # @test isapprox(∇2, (@. [1,-2*p[2]]*u0*λ[1]), atol=1e-15) -transformed_function = StochasticTransformedFunction(sol,sol.prob.f,sol.prob.g, (u,p,t)->p[2]^2*u) +transformed_function = StochasticTransformedFunction(sol, sol.prob.f, sol.prob.g, + (u, p, t) -> p[2]^2 * u) _dy, back = Zygote.pullback(u0, p) do u, p - vec(transformed_function(u, p, t)) + vec(transformed_function(u, p, t)) end -∇1,∇2 = back(λ) -@test isapprox(∇1, (p[1]-p[2]^2)*λ, atol=1e-15) -@test isapprox(∇2, (@. [1,-2*p[2]]*u0*λ[1]), atol=1e-15) - +∇1, ∇2 = back(λ) +@test isapprox(∇1, (p[1] - p[2]^2) * λ, atol = 1e-15) +@test isapprox(∇2, (@. [1, -2 * p[2]] * u0 * λ[1]), atol = 1e-15) ### # Check general compatibility of StochasticTransformedFunction() with ReverseDiff @@ -210,8 +215,8 @@ using ReverseDiff # scalar -tape = ReverseDiff.GradientTape((u0, p, [t])) do u,p,t - vec(f(u, p, t)-p[2]^2*u) +tape = ReverseDiff.GradientTape((u0, p, [t])) do u, p, t + vec(f(u, p, t) - p[2]^2 * u) end tu, tp, tt = ReverseDiff.input_hook(tape) @@ -228,16 +233,15 @@ ReverseDiff.forward_pass!(tape) ReverseDiff.increment_deriv!(output, λ) ReverseDiff.reverse_pass!(tape) -@test isapprox(ReverseDiff.deriv(tu), (p[1]-p[2]^2)*λ, atol=1e-15) -@test isapprox(ReverseDiff.deriv(tp), (@. [1,-2*p[2]]*u0*λ[1]), atol=1e-15) - +@test isapprox(ReverseDiff.deriv(tu), (p[1] - p[2]^2) * λ, atol = 1e-15) +@test isapprox(ReverseDiff.deriv(tp), (@. [1, -2 * p[2]] * u0 * λ[1]), atol = 1e-15) -tape = ReverseDiff.GradientTape((u0, p, [t])) do u,p,t - _dy, back = Zygote.pullback(u, p) do u, p - vec(σ(u, p, t)) - end - tmp1,tmp2 = back(_dy) - return f(u, p, t) - vec(tmp1) +tape = ReverseDiff.GradientTape((u0, p, [t])) do u, p, t + _dy, back = Zygote.pullback(u, p) do u, p + vec(σ(u, p, t)) + end + tmp1, tmp2 = back(_dy) + return f(u, p, t) - vec(tmp1) end tu, tp, tt = ReverseDiff.input_hook(tape) @@ -254,13 +258,13 @@ ReverseDiff.forward_pass!(tape) ReverseDiff.increment_deriv!(output, λ) ReverseDiff.reverse_pass!(tape) -@test isapprox(ReverseDiff.deriv(tu), (p[1]-p[2]^2)*λ, atol=1e-15) -@test isapprox(ReverseDiff.deriv(tp), (@. [1,-2*p[2]]*u0*λ[1]), atol=1e-15) +@test isapprox(ReverseDiff.deriv(tu), (p[1] - p[2]^2) * λ, atol = 1e-15) +@test isapprox(ReverseDiff.deriv(tp), (@. [1, -2 * p[2]] * u0 * λ[1]), atol = 1e-15) -transformed_function = StochasticTransformedFunction(sol,sol.prob.f,sol.prob.g) +transformed_function = StochasticTransformedFunction(sol, sol.prob.f, sol.prob.g) -tape = ReverseDiff.GradientTape((u0, p, [t])) do u,p,t - vec(transformed_function(u, p, first(t))) +tape = ReverseDiff.GradientTape((u0, p, [t])) do u, p, t + vec(transformed_function(u, p, first(t))) end tu, tp, tt = ReverseDiff.input_hook(tape) @@ -277,9 +281,8 @@ ReverseDiff.forward_pass!(tape) ReverseDiff.increment_deriv!(output, λ) ReverseDiff.reverse_pass!(tape) -@test isapprox(ReverseDiff.deriv(tu), (p[1]-p[2]^2)*λ, atol=1e-15) -@test isapprox(ReverseDiff.deriv(tp), (@. [1,-2*p[2]]*u0*λ[1]), atol=1e-15) - +@test isapprox(ReverseDiff.deriv(tu), (p[1] - p[2]^2) * λ, atol = 1e-15) +@test isapprox(ReverseDiff.deriv(tp), (@. [1, -2 * p[2]] * u0 * λ[1]), atol = 1e-15) # diagonal Random.seed!(seed) @@ -287,17 +290,16 @@ u0 = rand(3) λ = rand(3) _dy, back = Zygote.pullback(u0, p) do u, p - vec(f(u, p, t)-p[2]^2*u) + vec(f(u, p, t) - p[2]^2 * u) end -∇1,∇2 = back(λ) - -@test isapprox(∇1, (p[1]-p[2]^2)*λ, atol=1e-15) -@test isapprox(∇2[1], dot(u0,λ), atol=1e-15) -@test isapprox(∇2[2], -2*p[2]*dot(u0,λ), atol=1e-15) +∇1, ∇2 = back(λ) +@test isapprox(∇1, (p[1] - p[2]^2) * λ, atol = 1e-15) +@test isapprox(∇2[1], dot(u0, λ), atol = 1e-15) +@test isapprox(∇2[2], -2 * p[2] * dot(u0, λ), atol = 1e-15) -tape = ReverseDiff.GradientTape((u0, p, [t])) do u,p,t - vec(transformed_function(u, p, first(t))) +tape = ReverseDiff.GradientTape((u0, p, [t])) do u, p, t + vec(transformed_function(u, p, first(t))) end tu, tp, tt = ReverseDiff.input_hook(tape) @@ -316,31 +318,29 @@ ReverseDiff.reverse_pass!(tape) tmptp = ReverseDiff.deriv(tp) -@test isapprox(ReverseDiff.deriv(tu), (p[1]-p[2]^2)*λ, atol=1e-15) -@test isapprox(tmptp[1], dot(u0,λ), atol=1e-15) -@test isapprox(tmptp[2], -2*p[2]*dot(u0,λ), atol=1e-15) - +@test isapprox(ReverseDiff.deriv(tu), (p[1] - p[2]^2) * λ, atol = 1e-15) +@test isapprox(tmptp[1], dot(u0, λ), atol = 1e-15) +@test isapprox(tmptp[2], -2 * p[2] * dot(u0, λ), atol = 1e-15) # non-diagonal Random.seed!(seed) u0 = rand(2) -p = rand(1) -λ = rand(2) +p = rand(1) +λ = rand(2) _dy, back = Zygote.pullback(u0, p) do u, p - vec(fnd(u, p, t)) + vec(fnd(u, p, t)) end -∇1,∇2 = back(λ) +∇1, ∇2 = back(λ) -@test isapprox(∇1, zero(∇1), atol=1e-15) +@test isapprox(∇1, zero(∇1), atol = 1e-15) +prob = SDEProblem(fnd, σnd, u0, tspan, p, noise_rate_prototype = zeros(2, 4)) +sol = solve(prob, EM(), adaptive = false, dt = 0.001, save_noise = true) +transformed_function = StochasticTransformedFunction(sol, sol.prob.f, sol.prob.g) -prob = SDEProblem(fnd,σnd,u0,tspan,p,noise_rate_prototype=zeros(2,4)) -sol = solve(prob,EM(),adaptive=false, dt=0.001, save_noise=true) -transformed_function = StochasticTransformedFunction(sol,sol.prob.f,sol.prob.g) - -tape = ReverseDiff.GradientTape((u0, p, [t])) do u,p,t - vec(transformed_function(u, p, first(t))) +tape = ReverseDiff.GradientTape((u0, p, [t])) do u, p, t + vec(transformed_function(u, p, first(t))) end tu, tp, tt = ReverseDiff.input_hook(tape) @@ -357,9 +357,8 @@ ReverseDiff.forward_pass!(tape) ReverseDiff.increment_deriv!(output, λ) ReverseDiff.reverse_pass!(tape) -@test isapprox(ReverseDiff.deriv(tu), zero(u0), atol=1e-15) -@test isapprox(ReverseDiff.deriv(tp), zero(p), atol=1e-15) - +@test isapprox(ReverseDiff.deriv(tu), zero(u0), atol = 1e-15) +@test isapprox(ReverseDiff.deriv(tp), zero(p), atol = 1e-15) ### # Check Mutating functions @@ -370,25 +369,24 @@ u0 = rand(1) p = rand(2) λ = rand(1) -prob = SDEProblem(SDEFunction(f!,σ!,analytic=linear_analytic),σ!,u0,tspan,p) -sol = solve(prob,SOSRI(),adaptive=false, dt=0.001, save_noise=true) +prob = SDEProblem(SDEFunction(f!, σ!, analytic = linear_analytic), σ!, u0, tspan, p) +sol = solve(prob, SOSRI(), adaptive = false, dt = 0.001, save_noise = true) du = zeros(size(u0)) u = sol.u[end] -transformed_function = StochasticTransformedFunction(sol,sol.prob.f,sol.prob.g) - +transformed_function = StochasticTransformedFunction(sol, sol.prob.f, sol.prob.g) -function inplacefunc!(du,u,p,t) - du .= p[2]^2*u - return nothing +function inplacefunc!(du, u, p, t) + du .= p[2]^2 * u + return nothing end -tape = ReverseDiff.GradientTape((u0, p, [t])) do u,p,t - du1 = similar(u, size(u)) - du2 = similar(u, size(u)) - f!(du1, u, p, first(t)) - inplacefunc!(du2, u, p, first(t)) - return vec(du1-du2) +tape = ReverseDiff.GradientTape((u0, p, [t])) do u, p, t + du1 = similar(u, size(u)) + du2 = similar(u, size(u)) + f!(du1, u, p, first(t)) + inplacefunc!(du2, u, p, first(t)) + return vec(du1 - du2) end tu, tp, tt = ReverseDiff.input_hook(tape) # u0 @@ -405,21 +403,19 @@ ReverseDiff.forward_pass!(tape) ReverseDiff.increment_deriv!(output, λ) ReverseDiff.reverse_pass!(tape) -@test isapprox(ReverseDiff.deriv(tu), (p[1]-p[2]^2)*λ, atol=1e-15) # -0.016562475307537294 -@test isapprox(ReverseDiff.deriv(tp), (@. [1,-2*p[2]]*u0*λ[1]), atol=1e-15) #[0.017478629739736098, -0.023103635221731166] - - - -tape = ReverseDiff.GradientTape((u0, p, [t])) do u,p,t - _dy, back = Zygote.pullback(u, p) do u, p - out_ = Zygote.Buffer(similar(u)) - σ!(out_, u, p, t) - vec(copy(out_)) - end - tmp1,tmp2 = back(λ) - du1 = similar(u, size(u)) - f!(du1, u, p, first(t)) - return vec(du1-tmp1) +@test isapprox(ReverseDiff.deriv(tu), (p[1] - p[2]^2) * λ, atol = 1e-15) # -0.016562475307537294 +@test isapprox(ReverseDiff.deriv(tp), (@. [1, -2 * p[2]] * u0 * λ[1]), atol = 1e-15) #[0.017478629739736098, -0.023103635221731166] + +tape = ReverseDiff.GradientTape((u0, p, [t])) do u, p, t + _dy, back = Zygote.pullback(u, p) do u, p + out_ = Zygote.Buffer(similar(u)) + σ!(out_, u, p, t) + vec(copy(out_)) + end + tmp1, tmp2 = back(λ) + du1 = similar(u, size(u)) + f!(du1, u, p, first(t)) + return vec(du1 - tmp1) end tu, tp, tt = ReverseDiff.input_hook(tape) @@ -436,14 +432,13 @@ ReverseDiff.forward_pass!(tape) ReverseDiff.increment_deriv!(output, λ) ReverseDiff.reverse_pass!(tape) -@test_broken isapprox(ReverseDiff.deriv(tu), (p[1]-p[2]^2)*λ, atol=1e-15) -@test_broken isapprox(ReverseDiff.deriv(tp), (@. [1,-2*p[2]]*u0*λ[1]), atol=1e-15) +@test_broken isapprox(ReverseDiff.deriv(tu), (p[1] - p[2]^2) * λ, atol = 1e-15) +@test_broken isapprox(ReverseDiff.deriv(tp), (@. [1, -2 * p[2]] * u0 * λ[1]), atol = 1e-15) - -tape = ReverseDiff.GradientTape((u0, p, [t])) do u1,p1,t1 - du1 = similar(u1, size(u1)) - transformed_function(du1, u1, p1, first(t1)) - return vec(du1) +tape = ReverseDiff.GradientTape((u0, p, [t])) do u1, p1, t1 + du1 = similar(u1, size(u1)) + transformed_function(du1, u1, p1, first(t1)) + return vec(du1) end tu, tp, tt = ReverseDiff.input_hook(tape) # p[1]*u0 @@ -460,16 +455,16 @@ ReverseDiff.forward_pass!(tape) ReverseDiff.increment_deriv!(output, λ) ReverseDiff.reverse_pass!(tape) -@test isapprox(ReverseDiff.deriv(tu), (p[1]-p[2]^2)*λ, atol=1e-15) -@test isapprox(ReverseDiff.deriv(tp), (@. [1,-2*p[2]]*u0*λ[1]), atol=1e-15) - +@test isapprox(ReverseDiff.deriv(tu), (p[1] - p[2]^2) * λ, atol = 1e-15) +@test isapprox(ReverseDiff.deriv(tp), (@. [1, -2 * p[2]] * u0 * λ[1]), atol = 1e-15) -transformed_function = StochasticTransformedFunction(sol,sol.prob.f,sol.prob.g,(du,u,p,t)-> (du.=p[2]^2*u)) +transformed_function = StochasticTransformedFunction(sol, sol.prob.f, sol.prob.g, + (du, u, p, t) -> (du .= p[2]^2 * u)) -tape = ReverseDiff.GradientTape((u0, p, [t])) do u1,p1,t1 - du1 = similar(u1, size(u1)) - transformed_function(du1, u1, p1, first(t1)) - return vec(du1) +tape = ReverseDiff.GradientTape((u0, p, [t])) do u1, p1, t1 + du1 = similar(u1, size(u1)) + transformed_function(du1, u1, p1, first(t1)) + return vec(du1) end tu, tp, tt = ReverseDiff.input_hook(tape) # p[1]*u0 @@ -486,5 +481,5 @@ ReverseDiff.forward_pass!(tape) ReverseDiff.increment_deriv!(output, λ) ReverseDiff.reverse_pass!(tape) -@test isapprox(ReverseDiff.deriv(tu), (p[1]-p[2]^2)*λ, atol=1e-15) -@test isapprox(ReverseDiff.deriv(tp), (@. [1,-2*p[2]]*u0*λ[1]), atol=1e-15) +@test isapprox(ReverseDiff.deriv(tu), (p[1] - p[2]^2) * λ, atol = 1e-15) +@test isapprox(ReverseDiff.deriv(tp), (@. [1, -2 * p[2]] * u0 * λ[1]), atol = 1e-15) diff --git a/test/second_order.jl b/test/second_order.jl index df58906f6..d11ff35b6 100644 --- a/test/second_order.jl +++ b/test/second_order.jl @@ -1,31 +1,36 @@ -using SciMLSensitivity, OrdinaryDiffEq, DiffEqBase, ForwardDiff -using Test - -function fb(du,u,p,t) - du[1] = dx = p[1]*u[1] - p[2]*u[1]*u[2] - du[2] = dy = -p[3]*u[2] + p[4]*u[1]*u[2] -end - -function jac(J,u,p,t) - (x, y, a, b, c) = (u[1], u[2], p[1], p[2], p[3]) - J[1,1] = a + y * b * -1 - J[2,1] = y - J[1,2] = b * x * -1 - J[2,2] = c * -1 + x -end - -f = ODEFunction(fb,jac=jac) -p = [1.5,1.0,3.0,1.0]; u0 = [1.0;1.0] -prob = ODEProblem(f,u0,(0.0,10.0),p) -loss(sol) = sum(sol) -v = ones(4) - -H = second_order_sensitivities(loss,prob,Vern9(),saveat=0.1,abstol=1e-12,reltol=1e-12) -Hv = second_order_sensitivity_product(loss,v,prob,Vern9(),saveat=0.1,abstol=1e-12,reltol=1e-12) - -_loss(p) = loss(solve(prob,Vern9();u0=u0,p=p,saveat=0.1,abstol=1e-12,reltol=1e-12)) -H2 = ForwardDiff.hessian(_loss,p) -H2v = H*v - -@test H ≈ H2 -@test Hv ≈ H2v +using SciMLSensitivity, OrdinaryDiffEq, DiffEqBase, ForwardDiff +using Test + +function fb(du, u, p, t) + du[1] = dx = p[1] * u[1] - p[2] * u[1] * u[2] + du[2] = dy = -p[3] * u[2] + p[4] * u[1] * u[2] +end + +function jac(J, u, p, t) + (x, y, a, b, c) = (u[1], u[2], p[1], p[2], p[3]) + J[1, 1] = a + y * b * -1 + J[2, 1] = y + J[1, 2] = b * x * -1 + J[2, 2] = c * -1 + x +end + +f = ODEFunction(fb, jac = jac) +p = [1.5, 1.0, 3.0, 1.0]; +u0 = [1.0; 1.0]; +prob = ODEProblem(f, u0, (0.0, 10.0), p) +loss(sol) = sum(sol) +v = ones(4) + +H = second_order_sensitivities(loss, prob, Vern9(), saveat = 0.1, abstol = 1e-12, + reltol = 1e-12) +Hv = second_order_sensitivity_product(loss, v, prob, Vern9(), saveat = 0.1, abstol = 1e-12, + reltol = 1e-12) + +function _loss(p) + loss(solve(prob, Vern9(); u0 = u0, p = p, saveat = 0.1, abstol = 1e-12, reltol = 1e-12)) +end +H2 = ForwardDiff.hessian(_loss, p) +H2v = H * v + +@test H ≈ H2 +@test Hv ≈ H2v diff --git a/test/second_order_odes.jl b/test/second_order_odes.jl index 910f3e5ee..004c38e6f 100644 --- a/test/second_order_odes.jl +++ b/test/second_order_odes.jl @@ -1,22 +1,42 @@ -using OrdinaryDiffEq, SciMLSensitivity, Zygote, RecursiveArrayTools, Test - -u0 = Float32[1.; 2.] -du0 = Float32[0.; 2.] -tspan = (0.0f0, 1.0f0) -t = range(tspan[1], tspan[2], length=20) -p = Float32[1.01,0.9] -ff(du,u,p,t) = -p.*u -prob = SecondOrderODEProblem{false}(ff, du0, u0, tspan, p) -ddu01, du01, dp1 = Zygote.gradient((du0,u0,p)->sum(Array(solve(prob, Tsit5(), u0=ArrayPartition(du0,u0), p=p, saveat=t, sensealg = InterpolatingAdjoint(autojacvec=ZygoteVJP())))),du0,u0,p) -ddu02, du02, dp2 = Zygote.gradient((du0,u0,p)->sum(Array(solve(prob, Tsit5(), u0=ArrayPartition(du0,u0), p=p, saveat=t, sensealg = BacksolveAdjoint(autojacvec=ZygoteVJP())))),du0,u0,p) -ddu03, du03, dp3 = Zygote.gradient((du0,u0,p)->sum(Array(solve(prob, Tsit5(), u0=ArrayPartition(du0,u0), p=p, saveat=t, sensealg = QuadratureAdjoint(autojacvec=ZygoteVJP())))),du0,u0,p) -ddu04, du04, dp4 = Zygote.gradient((du0,u0,p)->sum(Array(solve(prob, Tsit5(), u0=ArrayPartition(du0,u0), p=p, saveat=t, sensealg = ForwardDiffSensitivity()))),du0,u0,p) -@test ddu01 ≈ ddu02 -@test ddu01 ≈ ddu03 -@test ddu01 ≈ ddu04 -@test du01 ≈ du02 -@test du01 ≈ du03 -@test du01 ≈ du04 -@test dp1 ≈ dp2 -@test dp1 ≈ dp3 -@test dp1 ≈ dp4 +using OrdinaryDiffEq, SciMLSensitivity, Zygote, RecursiveArrayTools, Test + +u0 = Float32[1.0; 2.0] +du0 = Float32[0.0; 2.0] +tspan = (0.0f0, 1.0f0) +t = range(tspan[1], tspan[2], length = 20) +p = Float32[1.01, 0.9] +ff(du, u, p, t) = -p .* u +prob = SecondOrderODEProblem{false}(ff, du0, u0, tspan, p) +ddu01, du01, dp1 = Zygote.gradient((du0, u0, p) -> sum(Array(solve(prob, Tsit5(), + u0 = ArrayPartition(du0, + u0), + p = p, saveat = t, + sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP())))), + du0, u0, p) +ddu02, du02, dp2 = Zygote.gradient((du0, u0, p) -> sum(Array(solve(prob, Tsit5(), + u0 = ArrayPartition(du0, + u0), + p = p, saveat = t, + sensealg = BacksolveAdjoint(autojacvec = ZygoteVJP())))), + du0, u0, p) +ddu03, du03, dp3 = Zygote.gradient((du0, u0, p) -> sum(Array(solve(prob, Tsit5(), + u0 = ArrayPartition(du0, + u0), + p = p, saveat = t, + sensealg = QuadratureAdjoint(autojacvec = ZygoteVJP())))), + du0, u0, p) +ddu04, du04, dp4 = Zygote.gradient((du0, u0, p) -> sum(Array(solve(prob, Tsit5(), + u0 = ArrayPartition(du0, + u0), + p = p, saveat = t, + sensealg = ForwardDiffSensitivity()))), + du0, u0, p) +@test ddu01 ≈ ddu02 +@test ddu01 ≈ ddu03 +@test ddu01 ≈ ddu04 +@test du01 ≈ du02 +@test du01 ≈ du03 +@test du01 ≈ du04 +@test dp1 ≈ dp2 +@test dp1 ≈ dp3 +@test dp1 ≈ dp4 diff --git a/test/shadowing.jl b/test/shadowing.jl index ab0a8bcaa..e080c1c98 100644 --- a/test/shadowing.jl +++ b/test/shadowing.jl @@ -1,4 +1,5 @@ -using Random; Random.seed!(1238) +using Random; +Random.seed!(1238); using OrdinaryDiffEq using Statistics using SciMLSensitivity @@ -6,503 +7,599 @@ using Test using Zygote @testset "LSS" begin - @info "LSS" - @testset "Lorenz single parameter" begin - function lorenz!(du,u,p,t) - du[1] = 10*(u[2]-u[1]) - du[2] = u[1]*(p[1]-u[3]) - u[2] - du[3] = u[1]*u[2] - (8//3)*u[3] + @info "LSS" + @testset "Lorenz single parameter" begin + function lorenz!(du, u, p, t) + du[1] = 10 * (u[2] - u[1]) + du[2] = u[1] * (p[1] - u[3]) - u[2] + du[3] = u[1] * u[2] - (8 // 3) * u[3] + end + + p = [28.0] + tspan_init = (0.0, 30.0) + tspan_attractor = (30.0, 50.0) + u0 = rand(3) + prob_init = ODEProblem(lorenz!, u0, tspan_init, p) + sol_init = solve(prob_init, Tsit5()) + prob_attractor = ODEProblem(lorenz!, sol_init[end], tspan_attractor, p) + sol_attractor = solve(prob_attractor, Vern9(), abstol = 1e-14, reltol = 1e-14) + + g(u, p, t) = u[end] + function dg(out, u, p, t, i) + fill!(out, zero(eltype(u))) + out[end] = one(eltype(u)) + end + lss_problem1 = ForwardLSSProblem(sol_attractor, ForwardLSS(g = g)) + lss_problem1a = ForwardLSSProblem(sol_attractor, ForwardLSS(g = g), + dg_continuous = dg) + lss_problem2 = ForwardLSSProblem(sol_attractor, + ForwardLSS(LSSregularizer = SciMLSensitivity.Cos2Windowing(), + g = g)) + lss_problem2a = ForwardLSSProblem(sol_attractor, + ForwardLSS(LSSregularizer = SciMLSensitivity.Cos2Windowing()), + dg_continuous = dg) + lss_problem3 = ForwardLSSProblem(sol_attractor, + ForwardLSS(LSSregularizer = SciMLSensitivity.TimeDilation(10.0), + g = g)) + lss_problem3a = ForwardLSSProblem(sol_attractor, + ForwardLSS(LSSregularizer = SciMLSensitivity.TimeDilation(10.0), + g = g), dg_continuous = dg) #ForwardLSS with time dilation requires knowledge of g + + adjointlss_problem = AdjointLSSProblem(sol_attractor, + AdjointLSS(LSSregularizer = SciMLSensitivity.TimeDilation(10.0), + g = g)) + adjointlss_problem_a = AdjointLSSProblem(sol_attractor, + AdjointLSS(LSSregularizer = SciMLSensitivity.TimeDilation(10.0), + g = g), dg_continuous = dg) + + res1 = shadow_forward(lss_problem1) + res1a = shadow_forward(lss_problem1a) + res2 = shadow_forward(lss_problem2) + res2a = shadow_forward(lss_problem2a) + res3 = shadow_forward(lss_problem3) + res3a = shadow_forward(lss_problem3a) + + res4 = shadow_adjoint(adjointlss_problem) + res4a = shadow_adjoint(adjointlss_problem_a) + + @test res1[1]≈1 atol=1e-1 + @test res2[1]≈1 atol=1e-1 + @test res3[1]≈1 atol=5e-2 + + @test res1≈res1a atol=1e-10 + @test res2≈res2a atol=1e-10 + @test res3≈res3a atol=1e-10 + @test res3≈res4 atol=1e-10 + @test res3≈res4a atol=1e-10 + + # fixed saveat to compare with concrete solve + sol_attractor2 = solve(prob_attractor, Vern9(), abstol = 1e-14, reltol = 1e-14, + saveat = 0.01) + lss_problem1 = ForwardLSSProblem(sol_attractor2, ForwardLSS(g = g)) + lss_problem1a = ForwardLSSProblem(sol_attractor2, ForwardLSS(g = g), + dg_continuous = dg) + lss_problem2 = ForwardLSSProblem(sol_attractor2, + ForwardLSS(LSSregularizer = SciMLSensitivity.Cos2Windowing(), + g = g)) + lss_problem2a = ForwardLSSProblem(sol_attractor2, + ForwardLSS(LSSregularizer = SciMLSensitivity.Cos2Windowing()), + dg_continuous = dg) + lss_problem3 = ForwardLSSProblem(sol_attractor2, + ForwardLSS(LSSregularizer = SciMLSensitivity.TimeDilation(10.0), + g = g)) + lss_problem3a = ForwardLSSProblem(sol_attractor2, + ForwardLSS(LSSregularizer = SciMLSensitivity.TimeDilation(10.0), + g = g), dg_continuous = dg) #ForwardLSS with time dilation requires knowledge of g + + adjointlss_problem = AdjointLSSProblem(sol_attractor2, + AdjointLSS(LSSregularizer = SciMLSensitivity.TimeDilation(10.0), + g = g)) + adjointlss_problem_a = AdjointLSSProblem(sol_attractor2, + AdjointLSS(LSSregularizer = SciMLSensitivity.TimeDilation(10.0), + g = g), dg_continuous = dg) + + res1 = shadow_forward(lss_problem1) + res1a = shadow_forward(lss_problem1a) + res2 = shadow_forward(lss_problem2) + res2a = shadow_forward(lss_problem2a) + res3 = shadow_forward(lss_problem3) + res3a = shadow_forward(lss_problem3a) + + res4 = shadow_adjoint(adjointlss_problem) + res4a = shadow_adjoint(adjointlss_problem_a) + + @test res1[1]≈1 atol=5e-2 + @test res2[1]≈1 atol=5e-2 + @test res3[1]≈1 atol=5e-2 + + @test res1≈res1a atol=1e-10 + @test res2≈res2a atol=1e-10 + @test res3≈res3a atol=1e-10 + @test res3≈res4 atol=1e-10 + @test res3≈res4a atol=1e-10 + + function G(p; sensealg = ForwardLSS(g = g), dt = 0.01) + _prob = remake(prob_attractor, p = p) + _sol = solve(_prob, Vern9(), abstol = 1e-14, reltol = 1e-14, saveat = dt, + sensealg = sensealg) + sum(getindex.(_sol.u, 3)) + end + + dp1 = Zygote.gradient((p) -> G(p), p) + @test res1≈dp1[1] atol=1e-10 + + dp1 = Zygote.gradient((p) -> G(p, + sensealg = ForwardLSS(LSSregularizer = SciMLSensitivity.Cos2Windowing())), + p) + @test res2≈dp1[1] atol=1e-10 + + dp1 = Zygote.gradient((p) -> G(p, + sensealg = ForwardLSS(LSSregularizer = SciMLSensitivity.TimeDilation(10.0), + g = g)), p) + @test res3≈dp1[1] atol=1e-10 + + dp1 = Zygote.gradient((p) -> G(p, + sensealg = AdjointLSS(LSSregularizer = SciMLSensitivity.TimeDilation(10.0), + g = g)), p) + @test res4≈dp1[1] atol=1e-10 + + @show res1[1] res2[1] res3[1] end - p = [28.0] - tspan_init = (0.0,30.0) - tspan_attractor = (30.0,50.0) - u0 = rand(3) - prob_init = ODEProblem(lorenz!,u0,tspan_init,p) - sol_init = solve(prob_init,Tsit5()) - prob_attractor = ODEProblem(lorenz!,sol_init[end],tspan_attractor,p) - sol_attractor = solve(prob_attractor,Vern9(),abstol=1e-14,reltol=1e-14) - - g(u,p,t) = u[end] - function dg(out,u,p,t,i) - fill!(out, zero(eltype(u))) - out[end] = one(eltype(u)) + @testset "Lorenz" begin + function lorenz!(du, u, p, t) + du[1] = p[1] * (u[2] - u[1]) + du[2] = u[1] * (p[2] - u[3]) - u[2] + du[3] = u[1] * u[2] - p[3] * u[3] + end + + p = [10.0, 28.0, 8 / 3] + + tspan_init = (0.0, 30.0) + tspan_attractor = (30.0, 50.0) + u0 = rand(3) + prob_init = ODEProblem(lorenz!, u0, tspan_init, p) + sol_init = solve(prob_init, Tsit5()) + prob_attractor = ODEProblem(lorenz!, sol_init[end], tspan_attractor, p) + sol_attractor = solve(prob_attractor, Vern9(), abstol = 1e-14, reltol = 1e-14) + + g(u, p, t) = u[end] + sum(p) + function dgu(out, u, p, t, i) + fill!(out, zero(eltype(u))) + out[end] = one(eltype(u)) + end + function dgp(out, u, p, t, i) + fill!(out, one(eltype(p))) + end + + lss_problem = ForwardLSSProblem(sol_attractor, + ForwardLSS(LSSregularizer = SciMLSensitivity.TimeDilation(10.0), + g = g)) + lss_problem_a = ForwardLSSProblem(sol_attractor, + ForwardLSS(LSSregularizer = SciMLSensitivity.TimeDilation(10.0), + g = g), dg_continuous = (dgu, dgp)) + adjointlss_problem = AdjointLSSProblem(sol_attractor, + AdjointLSS(LSSregularizer = SciMLSensitivity.TimeDilation(10.0), + g = g)) + adjointlss_problem_a = AdjointLSSProblem(sol_attractor, + AdjointLSS(LSSregularizer = SciMLSensitivity.TimeDilation(10.0), + g = g), + dg_continuous = (dgu, dgp)) + + resfw = shadow_forward(lss_problem) + resfw_a = shadow_forward(lss_problem_a) + resadj = shadow_adjoint(adjointlss_problem) + resadj_a = shadow_adjoint(adjointlss_problem_a) + + @test resfw≈resadj rtol=1e-10 + @test resfw≈resfw_a rtol=1e-10 + @test resfw≈resadj_a rtol=1e-10 + + sol_attractor2 = solve(prob_attractor, Vern9(), abstol = 1e-14, reltol = 1e-14, + saveat = 0.01) + lss_problem = ForwardLSSProblem(sol_attractor2, + ForwardLSS(LSSregularizer = SciMLSensitivity.TimeDilation(10.0), + g = g)) + resfw = shadow_forward(lss_problem) + + function G(p; sensealg = ForwardLSS(), dt = 0.01) + _prob = remake(prob_attractor, p = p) + _sol = solve(_prob, Vern9(), abstol = 1e-14, reltol = 1e-14, saveat = dt, + sensealg = sensealg) + sum(getindex.(_sol.u, 3)) + sum(p) + end + + dp1 = Zygote.gradient((p) -> G(p, + sensealg = ForwardLSS(LSSregularizer = SciMLSensitivity.TimeDilation(10.0), + g = g)), p) + @test resfw≈dp1[1] atol=1e-10 + + dp1 = Zygote.gradient((p) -> G(p, + sensealg = AdjointLSS(LSSregularizer = SciMLSensitivity.TimeDilation(10.0), + g = g)), p) + @test resfw≈dp1[1] atol=1e-10 + + @show resfw end - lss_problem1 = ForwardLSSProblem(sol_attractor, ForwardLSS(g=g)) - lss_problem1a = ForwardLSSProblem(sol_attractor, ForwardLSS(g=g), dg_continuous = dg) - lss_problem2 = ForwardLSSProblem(sol_attractor, ForwardLSS(LSSregularizer=SciMLSensitivity.Cos2Windowing(),g=g)) - lss_problem2a = ForwardLSSProblem(sol_attractor, ForwardLSS(LSSregularizer=SciMLSensitivity.Cos2Windowing()), dg_continuous = dg) - lss_problem3 = ForwardLSSProblem(sol_attractor, ForwardLSS(LSSregularizer=SciMLSensitivity.TimeDilation(10.0),g=g)) - lss_problem3a = ForwardLSSProblem(sol_attractor, ForwardLSS(LSSregularizer=SciMLSensitivity.TimeDilation(10.0),g=g), dg_continuous = dg) #ForwardLSS with time dilation requires knowledge of g - - adjointlss_problem = AdjointLSSProblem(sol_attractor, AdjointLSS(LSSregularizer=SciMLSensitivity.TimeDilation(10.0),g=g)) - adjointlss_problem_a = AdjointLSSProblem(sol_attractor, AdjointLSS(LSSregularizer=SciMLSensitivity.TimeDilation(10.0),g=g), dg_continuous = dg) - - res1 = shadow_forward(lss_problem1) - res1a = shadow_forward(lss_problem1a) - res2 = shadow_forward(lss_problem2) - res2a = shadow_forward(lss_problem2a) - res3 = shadow_forward(lss_problem3) - res3a = shadow_forward(lss_problem3a) - - res4 = shadow_adjoint(adjointlss_problem) - res4a = shadow_adjoint(adjointlss_problem_a) - - @test res1[1] ≈ 1 atol=1e-1 - @test res2[1] ≈ 1 atol=1e-1 - @test res3[1] ≈ 1 atol=5e-2 - - @test res1 ≈ res1a atol=1e-10 - @test res2 ≈ res2a atol=1e-10 - @test res3 ≈ res3a atol=1e-10 - @test res3 ≈ res4 atol=1e-10 - @test res3 ≈ res4a atol=1e-10 - - # fixed saveat to compare with concrete solve - sol_attractor2 = solve(prob_attractor,Vern9(),abstol=1e-14,reltol=1e-14, saveat=0.01) - lss_problem1 = ForwardLSSProblem(sol_attractor2, ForwardLSS(g=g)) - lss_problem1a = ForwardLSSProblem(sol_attractor2, ForwardLSS(g=g), dg_continuous = dg) - lss_problem2 = ForwardLSSProblem(sol_attractor2, ForwardLSS(LSSregularizer=SciMLSensitivity.Cos2Windowing(),g=g)) - lss_problem2a = ForwardLSSProblem(sol_attractor2, ForwardLSS(LSSregularizer=SciMLSensitivity.Cos2Windowing()), dg_continuous = dg) - lss_problem3 = ForwardLSSProblem(sol_attractor2, ForwardLSS(LSSregularizer=SciMLSensitivity.TimeDilation(10.0),g=g)) - lss_problem3a = ForwardLSSProblem(sol_attractor2, ForwardLSS(LSSregularizer=SciMLSensitivity.TimeDilation(10.0),g=g), dg_continuous = dg) #ForwardLSS with time dilation requires knowledge of g - - adjointlss_problem = AdjointLSSProblem(sol_attractor2, AdjointLSS(LSSregularizer=SciMLSensitivity.TimeDilation(10.0),g=g)) - adjointlss_problem_a = AdjointLSSProblem(sol_attractor2, AdjointLSS(LSSregularizer=SciMLSensitivity.TimeDilation(10.0),g=g), dg_continuous = dg) - - res1 = shadow_forward(lss_problem1) - res1a = shadow_forward(lss_problem1a) - res2 = shadow_forward(lss_problem2) - res2a = shadow_forward(lss_problem2a) - res3 = shadow_forward(lss_problem3) - res3a = shadow_forward(lss_problem3a) - - res4 = shadow_adjoint(adjointlss_problem) - res4a = shadow_adjoint(adjointlss_problem_a) - - @test res1[1] ≈ 1 atol=5e-2 - @test res2[1] ≈ 1 atol=5e-2 - @test res3[1] ≈ 1 atol=5e-2 - - @test res1 ≈ res1a atol=1e-10 - @test res2 ≈ res2a atol=1e-10 - @test res3 ≈ res3a atol=1e-10 - @test res3 ≈ res4 atol=1e-10 - @test res3 ≈ res4a atol=1e-10 - - function G(p; sensealg=ForwardLSS(g=g), dt=0.01) - _prob = remake(prob_attractor,p=p) - _sol = solve(_prob,Vern9(),abstol=1e-14,reltol=1e-14,saveat=dt,sensealg=sensealg) - sum(getindex.(_sol.u,3)) - end - - dp1 = Zygote.gradient((p)->G(p),p) - @test res1 ≈ dp1[1] atol=1e-10 - - dp1 = Zygote.gradient((p)->G(p, sensealg=ForwardLSS(LSSregularizer=SciMLSensitivity.Cos2Windowing())),p) - @test res2 ≈ dp1[1] atol=1e-10 - - dp1 = Zygote.gradient((p)->G(p, sensealg=ForwardLSS(LSSregularizer=SciMLSensitivity.TimeDilation(10.0),g=g)),p) - @test res3 ≈ dp1[1] atol=1e-10 - - dp1 = Zygote.gradient((p)->G(p, sensealg=AdjointLSS(LSSregularizer=SciMLSensitivity.TimeDilation(10.0),g=g)),p) - @test res4 ≈ dp1[1] atol=1e-10 - - @show res1[1] res2[1] res3[1] - end - - @testset "Lorenz" begin - function lorenz!(du,u,p,t) - du[1] = p[1]*(u[2]-u[1]) - du[2] = u[1]*(p[2]-u[3]) - u[2] - du[3] = u[1]*u[2] - p[3]*u[3] - end - - p = [10.0, 28.0, 8/3] - - tspan_init = (0.0,30.0) - tspan_attractor = (30.0,50.0) - u0 = rand(3) - prob_init = ODEProblem(lorenz!,u0,tspan_init,p) - sol_init = solve(prob_init,Tsit5()) - prob_attractor = ODEProblem(lorenz!,sol_init[end],tspan_attractor,p) - sol_attractor = solve(prob_attractor,Vern9(),abstol=1e-14,reltol=1e-14) - - g(u,p,t) = u[end] + sum(p) - function dgu(out,u,p,t,i) - fill!(out, zero(eltype(u))) - out[end] = one(eltype(u)) - end - function dgp(out,u,p,t,i) - fill!(out, one(eltype(p))) - end - - lss_problem = ForwardLSSProblem(sol_attractor, ForwardLSS(LSSregularizer=SciMLSensitivity.TimeDilation(10.0),g=g)) - lss_problem_a = ForwardLSSProblem(sol_attractor, ForwardLSS(LSSregularizer=SciMLSensitivity.TimeDilation(10.0),g=g), dg_continuous = (dgu,dgp)) - adjointlss_problem = AdjointLSSProblem(sol_attractor, AdjointLSS(LSSregularizer=SciMLSensitivity.TimeDilation(10.0),g=g)) - adjointlss_problem_a = AdjointLSSProblem(sol_attractor, AdjointLSS(LSSregularizer=SciMLSensitivity.TimeDilation(10.0),g=g), dg_continuous = (dgu,dgp)) - - resfw = shadow_forward(lss_problem) - resfw_a = shadow_forward(lss_problem_a) - resadj = shadow_adjoint(adjointlss_problem) - resadj_a = shadow_adjoint(adjointlss_problem_a) - @test resfw ≈ resadj rtol=1e-10 - @test resfw ≈ resfw_a rtol=1e-10 - @test resfw ≈ resadj_a rtol=1e-10 - - sol_attractor2 = solve(prob_attractor,Vern9(),abstol=1e-14,reltol=1e-14, saveat=0.01) - lss_problem = ForwardLSSProblem(sol_attractor2, ForwardLSS(LSSregularizer=SciMLSensitivity.TimeDilation(10.0),g=g)) - resfw = shadow_forward(lss_problem) - - function G(p; sensealg=ForwardLSS(), dt=0.01) - _prob = remake(prob_attractor,p=p) - _sol = solve(_prob,Vern9(),abstol=1e-14,reltol=1e-14,saveat=dt,sensealg=sensealg) - sum(getindex.(_sol.u,3)) + sum(p) - end - - dp1 = Zygote.gradient((p)->G(p, sensealg=ForwardLSS(LSSregularizer=SciMLSensitivity.TimeDilation(10.0),g=g)),p) - @test resfw ≈ dp1[1] atol=1e-10 - - dp1 = Zygote.gradient((p)->G(p, sensealg=AdjointLSS(LSSregularizer=SciMLSensitivity.TimeDilation(10.0),g=g)),p) - @test resfw ≈ dp1[1] atol=1e-10 - - @show resfw - end - - @testset "T0skip and T1skip" begin - function lorenz!(du,u,p,t) - du[1] = p[1]*(u[2]-u[1]) - du[2] = u[1]*(p[2]-u[3]) - u[2] - du[3] = u[1]*u[2] - p[3]*u[3] + @testset "T0skip and T1skip" begin + function lorenz!(du, u, p, t) + du[1] = p[1] * (u[2] - u[1]) + du[2] = u[1] * (p[2] - u[3]) - u[2] + du[3] = u[1] * u[2] - p[3] * u[3] + end + + p = [10.0, 28.0, 8 / 3] + + tspan_init = (0.0, 30.0) + tspan_attractor = (30.0, 50.0) + u0 = rand(3) + prob_init = ODEProblem(lorenz!, u0, tspan_init, p) + sol_init = solve(prob_init, Tsit5()) + prob_attractor = ODEProblem(lorenz!, sol_init[end], tspan_attractor, p) + sol_attractor = solve(prob_attractor, Vern9(), abstol = 1e-14, reltol = 1e-14, + saveat = 0.01) + + g(u, p, t) = u[end]^2 / 2 + sum(p) + function dgu(out, u, p, t, i) + fill!(out, zero(eltype(u))) + out[end] = u[end] + end + function dgp(out, u, p, t, i) + fill!(out, one(eltype(p))) + end + + function G(p; sensealg = ForwardLSS(g = g), dt = 0.01) + _prob = remake(prob_attractor, p = p) + _sol = solve(_prob, Vern9(), abstol = 1e-14, reltol = 1e-14, saveat = dt, + sensealg = sensealg) + sum(getindex.(_sol.u, 3) .^ 2) / 2 + sum(p) + end + + ## ForwardLSS + + lss_problem = ForwardLSSProblem(sol_attractor, + ForwardLSS(LSSregularizer = SciMLSensitivity.TimeDilation(10.0), + g = g)) + resfw = shadow_forward(lss_problem) + + res = deepcopy(resfw) + + dp1 = Zygote.gradient((p) -> G(p, + sensealg = ForwardLSS(LSSregularizer = SciMLSensitivity.TimeDilation(10.0), + g = g)), p) + @test res≈dp1[1] atol=1e-10 + + resfw = shadow_forward(lss_problem, + sensealg = ForwardLSS(LSSregularizer = SciMLSensitivity.TimeDilation(10.0, + 10.0, + 5.0), + g = g)) + resskip = deepcopy(resfw) + + dp1 = Zygote.gradient((p) -> G(p, + sensealg = ForwardLSS(LSSregularizer = SciMLSensitivity.TimeDilation(10.0, + 10.0, + 5.0), + g = g)), p) + @test resskip≈dp1[1] atol=1e-10 + + @show res resskip + + ## ForwardLSS with dgdu and dgdp + + lss_problem = ForwardLSSProblem(sol_attractor, + ForwardLSS(LSSregularizer = SciMLSensitivity.TimeDilation(10.0), + g = g), dg_continuous = (dgu, dgp)) + res2 = shadow_forward(lss_problem) + @test res≈res2 atol=1e-10 + res2 = shadow_forward(lss_problem, + sensealg = ForwardLSS(LSSregularizer = SciMLSensitivity.TimeDilation(10.0, + 10.0, + 5.0), + g = g)) + @test resskip≈res2 atol=1e-10 + + ## AdjointLSS + + lss_problem = AdjointLSSProblem(sol_attractor, + AdjointLSS(LSSregularizer = SciMLSensitivity.TimeDilation(10.0), + g = g)) + res2 = shadow_adjoint(lss_problem) + @test res≈res2 atol=1e-10 + res2 = shadow_adjoint(lss_problem, + sensealg = AdjointLSS(LSSregularizer = SciMLSensitivity.TimeDilation(10.0, + 10.0, + 5.0), + g = g)) + @test_broken resskip≈res2 atol=1e-10 + + dp1 = Zygote.gradient((p) -> G(p, + sensealg = AdjointLSS(LSSregularizer = SciMLSensitivity.TimeDilation(10.0), + g = g)), p) + @test res≈dp1[1] atol=1e-10 + + dp1 = Zygote.gradient((p) -> G(p, + sensealg = AdjointLSS(LSSregularizer = SciMLSensitivity.TimeDilation(10.0, + 10.0, + 5.0), + g = g)), p) + @test res2≈dp1[1] atol=1e-10 + + ## AdjointLSS with dgdu and dgd + + lss_problem = AdjointLSSProblem(sol_attractor, + AdjointLSS(LSSregularizer = SciMLSensitivity.TimeDilation(10.0), + g = g), dg_continuous = (dgu, dgp)) + res2 = shadow_adjoint(lss_problem) + @test res≈res2 atol=1e-10 + res2 = shadow_adjoint(lss_problem, + sensealg = AdjointLSS(LSSregularizer = SciMLSensitivity.TimeDilation(10.0, + 10.0, + 5.0), + g = g)) + @test_broken resskip≈res2 atol=1e-10 end - - p = [10.0, 28.0, 8/3] - - tspan_init = (0.0,30.0) - tspan_attractor = (30.0,50.0) - u0 = rand(3) - prob_init = ODEProblem(lorenz!,u0,tspan_init,p) - sol_init = solve(prob_init,Tsit5()) - prob_attractor = ODEProblem(lorenz!,sol_init[end],tspan_attractor,p) - sol_attractor = solve(prob_attractor,Vern9(),abstol=1e-14,reltol=1e-14, saveat=0.01) - - g(u,p,t) = u[end]^2/2 + sum(p) - function dgu(out,u,p,t,i) - fill!(out, zero(eltype(u))) - out[end] = u[end] - end - function dgp(out,u,p,t,i) - fill!(out, one(eltype(p))) - end - - function G(p; sensealg=ForwardLSS(g=g), dt=0.01) - _prob = remake(prob_attractor,p=p) - _sol = solve(_prob,Vern9(),abstol=1e-14,reltol=1e-14,saveat=dt,sensealg=sensealg) - sum(getindex.(_sol.u,3).^2)/2 + sum(p) - end - - ## ForwardLSS - - lss_problem = ForwardLSSProblem(sol_attractor, ForwardLSS(LSSregularizer=SciMLSensitivity.TimeDilation(10.0), g=g)) - resfw = shadow_forward(lss_problem) - - res = deepcopy(resfw) - - dp1 = Zygote.gradient((p)->G(p, sensealg=ForwardLSS(LSSregularizer=SciMLSensitivity.TimeDilation(10.0),g=g)),p) - @test res ≈ dp1[1] atol=1e-10 - - resfw = shadow_forward(lss_problem, sensealg = ForwardLSS(LSSregularizer=SciMLSensitivity.TimeDilation(10.0,10.0,5.0), g=g)) - resskip = deepcopy(resfw) - - dp1 = Zygote.gradient((p)->G(p, sensealg=ForwardLSS(LSSregularizer=SciMLSensitivity.TimeDilation(10.0,10.0,5.0), g=g)),p) - @test resskip ≈ dp1[1] atol=1e-10 - - @show res resskip - - ## ForwardLSS with dgdu and dgdp - - lss_problem = ForwardLSSProblem(sol_attractor, ForwardLSS(LSSregularizer=SciMLSensitivity.TimeDilation(10.0),g=g), dg_continuous = (dgu,dgp)) - res2 = shadow_forward(lss_problem) - @test res ≈ res2 atol=1e-10 - res2 = shadow_forward(lss_problem, sensealg = ForwardLSS(LSSregularizer=SciMLSensitivity.TimeDilation(10.0,10.0,5.0), g=g)) - @test resskip ≈ res2 atol=1e-10 - - ## AdjointLSS - - lss_problem = AdjointLSSProblem(sol_attractor, AdjointLSS(LSSregularizer=SciMLSensitivity.TimeDilation(10.0),g=g)) - res2 = shadow_adjoint(lss_problem) - @test res ≈ res2 atol=1e-10 - res2 = shadow_adjoint(lss_problem, sensealg = AdjointLSS(LSSregularizer=SciMLSensitivity.TimeDilation(10.0,10.0,5.0), g=g)) - @test_broken resskip ≈ res2 atol=1e-10 - - dp1 = Zygote.gradient((p)->G(p, sensealg=AdjointLSS(LSSregularizer=SciMLSensitivity.TimeDilation(10.0),g=g)),p) - @test res ≈ dp1[1] atol=1e-10 - - dp1 = Zygote.gradient((p)->G(p, sensealg=AdjointLSS(LSSregularizer=SciMLSensitivity.TimeDilation(10.0,10.0,5.0), g=g)),p) - @test res2 ≈ dp1[1] atol=1e-10 - - ## AdjointLSS with dgdu and dgd - - lss_problem = AdjointLSSProblem(sol_attractor, AdjointLSS(LSSregularizer=SciMLSensitivity.TimeDilation(10.0), g=g), dg_continuous = (dgu,dgp)) - res2 = shadow_adjoint(lss_problem) - @test res ≈ res2 atol=1e-10 - res2 = shadow_adjoint(lss_problem, sensealg = AdjointLSS(LSSregularizer=SciMLSensitivity.TimeDilation(10.0,10.0,5.0), g=g)) - @test_broken resskip ≈ res2 atol=1e-10 - end end @testset "NILSS" begin - @info "NILSS" - @testset "Lorenz single parameter" begin - function lorenz!(du,u,p,t) - du[1] = 10*(u[2]-u[1]) - du[2] = u[1]*(p[1]-u[3]) - u[2] - du[3] = u[1]*u[2] - (8//3)*u[3] - end - - p = [28.0] - tspan_init = (0.0,100.0) - tspan_attractor = (100.0,120.0) - u0 = rand(3) - prob_init = ODEProblem(lorenz!,u0,tspan_init,p) - sol_init = solve(prob_init,Tsit5()) - prob_attractor = ODEProblem(lorenz!,sol_init[end],tspan_attractor,p) - - g(u,p,t) = u[end] - function dg(out,u,p,t,i) - fill!(out, zero(eltype(u))) - out[end] = one(eltype(u)) - end - - nseg = 50 # number of segments on time interval - nstep = 2001 # number of steps on each segment - - # fix seed here for res1==res2 check, otherwise hom. tangent - # are initialized randomly - Random.seed!(1234) - nilss_prob1 = NILSSProblem(prob_attractor, NILSS(nseg, nstep, g=g)) - res1 = SciMLSensitivity.shadow_forward(nilss_prob1,Tsit5()) - - Random.seed!(1234) - nilss_prob2 = NILSSProblem(prob_attractor, NILSS(nseg, nstep, g=g), dg_continuous = dg) - res2 = SciMLSensitivity.shadow_forward(nilss_prob2,Tsit5()) - - @test res1[1] ≈ 1 atol=5e-2 - @test res2[1] ≈ 1 atol=5e-2 - @test res1 ≈ res2 atol=1e-10 - - function G(p; dt=nilss_prob1.dtsave) - _prob = remake(prob_attractor,p=p) - _sol = solve(_prob,Tsit5(),saveat=dt,sensealg=NILSS(nseg, nstep, g=g)) - sum(getindex.(_sol.u,3)) + @info "NILSS" + @testset "Lorenz single parameter" begin + function lorenz!(du, u, p, t) + du[1] = 10 * (u[2] - u[1]) + du[2] = u[1] * (p[1] - u[3]) - u[2] + du[3] = u[1] * u[2] - (8 // 3) * u[3] + end + + p = [28.0] + tspan_init = (0.0, 100.0) + tspan_attractor = (100.0, 120.0) + u0 = rand(3) + prob_init = ODEProblem(lorenz!, u0, tspan_init, p) + sol_init = solve(prob_init, Tsit5()) + prob_attractor = ODEProblem(lorenz!, sol_init[end], tspan_attractor, p) + + g(u, p, t) = u[end] + function dg(out, u, p, t, i) + fill!(out, zero(eltype(u))) + out[end] = one(eltype(u)) + end + + nseg = 50 # number of segments on time interval + nstep = 2001 # number of steps on each segment + + # fix seed here for res1==res2 check, otherwise hom. tangent + # are initialized randomly + Random.seed!(1234) + nilss_prob1 = NILSSProblem(prob_attractor, NILSS(nseg, nstep, g = g)) + res1 = SciMLSensitivity.shadow_forward(nilss_prob1, Tsit5()) + + Random.seed!(1234) + nilss_prob2 = NILSSProblem(prob_attractor, NILSS(nseg, nstep, g = g), + dg_continuous = dg) + res2 = SciMLSensitivity.shadow_forward(nilss_prob2, Tsit5()) + + @test res1[1]≈1 atol=5e-2 + @test res2[1]≈1 atol=5e-2 + @test res1≈res2 atol=1e-10 + + function G(p; dt = nilss_prob1.dtsave) + _prob = remake(prob_attractor, p = p) + _sol = solve(_prob, Tsit5(), saveat = dt, sensealg = NILSS(nseg, nstep, g = g)) + sum(getindex.(_sol.u, 3)) + end + + Random.seed!(1234) + dp1 = Zygote.gradient((p) -> G(p), p) + @test res1≈dp1[1] atol=1e-10 end - Random.seed!(1234) - dp1 = Zygote.gradient((p)->G(p),p) - @test res1 ≈ dp1[1] atol=1e-10 - end - - @testset "Lorenz" begin - # Here we test LSS output to NILSS output w/ multiple params - function lorenz!(du,u,p,t) - du[1] = p[1]*(u[2]-u[1]) - du[2] = u[1]*(p[2]-u[3]) - u[2] - du[3] = u[1]*u[2] - p[3]*u[3] - end + @testset "Lorenz" begin + # Here we test LSS output to NILSS output w/ multiple params + function lorenz!(du, u, p, t) + du[1] = p[1] * (u[2] - u[1]) + du[2] = u[1] * (p[2] - u[3]) - u[2] + du[3] = u[1] * u[2] - p[3] * u[3] + end - p = [10.0, 28.0, 8/3] - u0 = rand(3) + p = [10.0, 28.0, 8 / 3] + u0 = rand(3) - # Relatively short tspan_attractor since increasing more infeasible w/ - # computational cost of LSS - tspan_init = (0.0,100.0) - tspan_attractor = (100.0,120.0) + # Relatively short tspan_attractor since increasing more infeasible w/ + # computational cost of LSS + tspan_init = (0.0, 100.0) + tspan_attractor = (100.0, 120.0) - prob_init = ODEProblem(lorenz!,u0,tspan_init,p) - sol_init = solve(prob_init,Tsit5()) + prob_init = ODEProblem(lorenz!, u0, tspan_init, p) + sol_init = solve(prob_init, Tsit5()) - prob_attractor = ODEProblem(lorenz!,sol_init[end],tspan_attractor,p) - sol_attractor = solve(prob_attractor,Vern9(),abstol=1e-14,reltol=1e-14) + prob_attractor = ODEProblem(lorenz!, sol_init[end], tspan_attractor, p) + sol_attractor = solve(prob_attractor, Vern9(), abstol = 1e-14, reltol = 1e-14) - g(u,p,t) = u[end] + g(u, p, t) = u[end] - lss_problem = ForwardLSSProblem(sol_attractor, ForwardLSS(LSSregularizer=SciMLSensitivity.TimeDilation(10.0),g=g)) + lss_problem = ForwardLSSProblem(sol_attractor, + ForwardLSS(LSSregularizer = SciMLSensitivity.TimeDilation(10.0), + g = g)) - resfw = shadow_forward(lss_problem) + resfw = shadow_forward(lss_problem) - # NILSS can handle w/ longer timespan and get lower noise in sensitivity estimate - tspan_init = (0.0,100.0) - tspan_attractor = (100.0,150.0) + # NILSS can handle w/ longer timespan and get lower noise in sensitivity estimate + tspan_init = (0.0, 100.0) + tspan_attractor = (100.0, 150.0) - prob_init = ODEProblem(lorenz!,u0,tspan_init,p) - sol_init = solve(prob_init,Tsit5()) + prob_init = ODEProblem(lorenz!, u0, tspan_init, p) + sol_init = solve(prob_init, Tsit5()) - prob_attractor = ODEProblem(lorenz!,sol_init[end],tspan_attractor,p) - sol_attractor = solve(prob_attractor,Vern9(),abstol=1e-14,reltol=1e-14) + prob_attractor = ODEProblem(lorenz!, sol_init[end], tspan_attractor, p) + sol_attractor = solve(prob_attractor, Vern9(), abstol = 1e-14, reltol = 1e-14) - nseg = 50 # number of segments on time interval - nstep = 2001 # number of steps on each segment + nseg = 50 # number of segments on time interval + nstep = 2001 # number of steps on each segment - nilss_prob = NILSSProblem(prob_attractor, NILSS(nseg, nstep; g)); - res = shadow_forward(nilss_prob, Tsit5()) + nilss_prob = NILSSProblem(prob_attractor, NILSS(nseg, nstep; g)) + res = shadow_forward(nilss_prob, Tsit5()) - # There is larger noise in LSS estimate of parameter 3 due to shorter timespan considered, - # so test tolerance for parameter 3 is larger. - @test resfw[1] ≈ res[1] atol=5e-2 - @test resfw[2] ≈ res[2] atol=5e-2 - @test resfw[3] ≈ res[3] atol=5e-1 - end + # There is larger noise in LSS estimate of parameter 3 due to shorter timespan considered, + # so test tolerance for parameter 3 is larger. + @test resfw[1]≈res[1] atol=5e-2 + @test resfw[2]≈res[2] atol=5e-2 + @test resfw[3]≈res[3] atol=5e-1 + end end @testset "NILSAS" begin - @info "NILSAS" - @testset "nilsas_min function" begin - u0 = rand(3) - M = 2 - nseg = 2 - numparams = 1 - quadcache = SciMLSensitivity.QuadratureCache(u0, M, nseg, numparams) - - C = quadcache.C - C[:,:,1] .= [ - 1. 0. - 0. 1.] - C[:,:,2] .= [ - 4. 0. - 0. 1.] - - dwv = quadcache.dwv - dwv[:,1] .= [1., 0.] - dwv[:,2] .= [1., 4.] - - dwf = quadcache.dwf - dwf[:,1] .= [1., 1.] - dwf[:,2] .= [3., 1.] - - dvf = quadcache.dvf - dvf[1] = 1. - dvf[2] = 2. - - R = quadcache.R - R[:,:,1] .= [ - Inf Inf - Inf Inf] - R[:,:,2] .= [ - 1. 1. - 0. 2.] - - b = quadcache.b - b[:,1] = [Inf, Inf] - b[:,2] = [0., 1.] - - @test SciMLSensitivity.nilsas_min(quadcache) ≈ [-1. 0. - -1. -1.] - end - @testset "Lorenz" begin - function lorenz!(du,u,p,t) - du[1] = p[1]*(u[2]-u[1]) - du[2] = u[1]*(p[2]-u[3]) - u[2] - du[3] = u[1]*u[2] - p[3]*u[3] - return nothing + @info "NILSAS" + @testset "nilsas_min function" begin + u0 = rand(3) + M = 2 + nseg = 2 + numparams = 1 + quadcache = SciMLSensitivity.QuadratureCache(u0, M, nseg, numparams) + + C = quadcache.C + C[:, :, 1] .= [1.0 0.0 + 0.0 1.0] + C[:, :, 2] .= [4.0 0.0 + 0.0 1.0] + + dwv = quadcache.dwv + dwv[:, 1] .= [1.0, 0.0] + dwv[:, 2] .= [1.0, 4.0] + + dwf = quadcache.dwf + dwf[:, 1] .= [1.0, 1.0] + dwf[:, 2] .= [3.0, 1.0] + + dvf = quadcache.dvf + dvf[1] = 1.0 + dvf[2] = 2.0 + + R = quadcache.R + R[:, :, 1] .= [Inf Inf + Inf Inf] + R[:, :, 2] .= [1.0 1.0 + 0.0 2.0] + + b = quadcache.b + b[:, 1] = [Inf, Inf] + b[:, 2] = [0.0, 1.0] + + @test SciMLSensitivity.nilsas_min(quadcache) ≈ [-1.0 0.0 + -1.0 -1.0] end + @testset "Lorenz" begin + function lorenz!(du, u, p, t) + du[1] = p[1] * (u[2] - u[1]) + du[2] = u[1] * (p[2] - u[3]) - u[2] + du[3] = u[1] * u[2] - p[3] * u[3] + return nothing + end - u0_trans = rand(3) - p = [10.0, 28.0, 8/3] - - # parameter passing to NILSAS - M = 2 - nseg = 40 - nstep = 101 + u0_trans = rand(3) + p = [10.0, 28.0, 8 / 3] - tspan_transient = (0.0,30.0) - prob_transient = ODEProblem(lorenz!,u0_trans,tspan_transient,p) - sol_transient = solve(prob_transient, Tsit5()) + # parameter passing to NILSAS + M = 2 + nseg = 40 + nstep = 101 - u0 = sol_transient.u[end] + tspan_transient = (0.0, 30.0) + prob_transient = ODEProblem(lorenz!, u0_trans, tspan_transient, p) + sol_transient = solve(prob_transient, Tsit5()) - tspan_attractor = (0.0,40.0) - prob_attractor = ODEProblem(lorenz!,u0,tspan_attractor,p) - sol_attractor = solve(prob_attractor,Vern9(),abstol=1e-14,reltol=1e-14,saveat=0.01) + u0 = sol_transient.u[end] - g(u,p,t) = u[end] - function dg(out,u,p,t,i=nothing) - fill!(out, zero(eltype(u))) - out[end] = one(eltype(u)) - end + tspan_attractor = (0.0, 40.0) + prob_attractor = ODEProblem(lorenz!, u0, tspan_attractor, p) + sol_attractor = solve(prob_attractor, Vern9(), abstol = 1e-14, reltol = 1e-14, + saveat = 0.01) - lss_problem = ForwardLSSProblem(sol_attractor, ForwardLSS(LSSregularizer=SciMLSensitivity.TimeDilation(10.0), g=g), dg_continuous = dg) - resfw = shadow_forward(lss_problem) + g(u, p, t) = u[end] + function dg(out, u, p, t, i = nothing) + fill!(out, zero(eltype(u))) + out[end] = one(eltype(u)) + end - @info resfw + lss_problem = ForwardLSSProblem(sol_attractor, + ForwardLSS(LSSregularizer = SciMLSensitivity.TimeDilation(10.0), + g = g), dg_continuous = dg) + resfw = shadow_forward(lss_problem) - nilsas_prob = NILSASProblem(sol_attractor, NILSAS(nseg,nstep,M, g=g)) - res = shadow_adjoint(nilsas_prob, Tsit5()) + @info resfw - @info res + nilsas_prob = NILSASProblem(sol_attractor, NILSAS(nseg, nstep, M, g = g)) + res = shadow_adjoint(nilsas_prob, Tsit5()) - @test resfw ≈ res atol=1e-1 + @info res - nilsas_prob = NILSASProblem(sol_attractor, NILSAS(nseg,nstep,M, g=g), dg_continuous = dg) - res = shadow_adjoint(nilsas_prob, Tsit5()) + @test resfw≈res atol=1e-1 - @info res + nilsas_prob = NILSASProblem(sol_attractor, NILSAS(nseg, nstep, M, g = g), + dg_continuous = dg) + res = shadow_adjoint(nilsas_prob, Tsit5()) - @test resfw ≈ res atol=1e-1 - end + @info res - @testset "Lorenz parameter-dependent loss function" begin - function lorenz!(du,u,p,t) - du[1] = p[1]*(u[2]-u[1]) - du[2] = u[1]*(p[2]-u[3]) - u[2] - du[3] = u[1]*u[2] - p[3]*u[3] - return nothing + @test resfw≈res atol=1e-1 end - u0_trans = rand(3) - p = [10.0, 28.0, 8/3] + @testset "Lorenz parameter-dependent loss function" begin + function lorenz!(du, u, p, t) + du[1] = p[1] * (u[2] - u[1]) + du[2] = u[1] * (p[2] - u[3]) - u[2] + du[3] = u[1] * u[2] - p[3] * u[3] + return nothing + end - # parameter passing to NILSAS - M = 2 - nseg = 100 - nstep = 101 + u0_trans = rand(3) + p = [10.0, 28.0, 8 / 3] - tspan_transient = (0.0,100.0) - prob_transient = ODEProblem(lorenz!,u0_trans,tspan_transient,p) - sol_transient = solve(prob_transient, Tsit5()) + # parameter passing to NILSAS + M = 2 + nseg = 100 + nstep = 101 - u0 = sol_transient.u[end] + tspan_transient = (0.0, 100.0) + prob_transient = ODEProblem(lorenz!, u0_trans, tspan_transient, p) + sol_transient = solve(prob_transient, Tsit5()) - tspan_attractor = (0.0,50.0) - prob_attractor = ODEProblem(lorenz!,u0,tspan_attractor,p) - sol_attractor = solve(prob_attractor,Vern9(),abstol=1e-14,reltol=1e-14,saveat=0.01) + u0 = sol_transient.u[end] - g(u,p,t) = u[end]^2/2 + sum(p) - function dgu(out,u,p,t,i=nothing) - fill!(out, zero(eltype(u))) - out[end] = u[end] - end - function dgp(out,u,p,t,i=nothing) - fill!(out, one(eltype(p))) - end + tspan_attractor = (0.0, 50.0) + prob_attractor = ODEProblem(lorenz!, u0, tspan_attractor, p) + sol_attractor = solve(prob_attractor, Vern9(), abstol = 1e-14, reltol = 1e-14, + saveat = 0.01) + + g(u, p, t) = u[end]^2 / 2 + sum(p) + function dgu(out, u, p, t, i = nothing) + fill!(out, zero(eltype(u))) + out[end] = u[end] + end + function dgp(out, u, p, t, i = nothing) + fill!(out, one(eltype(p))) + end - lss_problem = ForwardLSSProblem(sol_attractor, ForwardLSS(LSSregularizer=SciMLSensitivity.TimeDilation(10.0), g=g), dg_continuous = (dgu,dgp)) - resfw = shadow_forward(lss_problem) + lss_problem = ForwardLSSProblem(sol_attractor, + ForwardLSS(LSSregularizer = SciMLSensitivity.TimeDilation(10.0), + g = g), dg_continuous = (dgu, dgp)) + resfw = shadow_forward(lss_problem) - @info resfw + @info resfw - nilsas_prob = NILSASProblem(sol_attractor, NILSAS(nseg,nstep,M, g=g)) - res = shadow_adjoint(nilsas_prob, Tsit5()) + nilsas_prob = NILSASProblem(sol_attractor, NILSAS(nseg, nstep, M, g = g)) + res = shadow_adjoint(nilsas_prob, Tsit5()) - @info res + @info res - @test resfw ≈ res rtol=1e-1 + @test resfw≈res rtol=1e-1 - nilsas_prob = NILSASProblem(sol_attractor, NILSAS(nseg,nstep,M, g=g), dg_continuous = (dgu,dgp)) - res = shadow_adjoint(nilsas_prob, Tsit5()) + nilsas_prob = NILSASProblem(sol_attractor, NILSAS(nseg, nstep, M, g = g), + dg_continuous = (dgu, dgp)) + res = shadow_adjoint(nilsas_prob, Tsit5()) - @info res + @info res - @test resfw ≈ res rtol=1e-1 - end + @test resfw≈res rtol=1e-1 + end end diff --git a/test/size_handling_adjoint.jl b/test/size_handling_adjoint.jl index 0412a19c8..94717551c 100644 --- a/test/size_handling_adjoint.jl +++ b/test/size_handling_adjoint.jl @@ -1,32 +1,32 @@ using SciMLSensitivity, Flux, OrdinaryDiffEq, Test # , Plots -p = [1.5 1.0;3.0 1.0] -function lotka_volterra(du,u,p,t) - du[1] = p[1,1]*u[1] - p[1,2]*u[1]*u[2] - du[2] = -p[2,1]*u[2] + p[2,2]*u[1]*u[2] +p = [1.5 1.0; 3.0 1.0] +function lotka_volterra(du, u, p, t) + du[1] = p[1, 1] * u[1] - p[1, 2] * u[1] * u[2] + du[2] = -p[2, 1] * u[2] + p[2, 2] * u[1] * u[2] end -u0 = [1.0,1.0] -tspan = (0.0,10.0) +u0 = [1.0, 1.0] +tspan = (0.0, 10.0) -prob = ODEProblem(lotka_volterra,u0,tspan,p) -sol = solve(prob,Tsit5()) +prob = ODEProblem(lotka_volterra, u0, tspan, p) +sol = solve(prob, Tsit5()) # plot(sol) -p = [2.2 1.0;2.0 0.4] # Tweaked Initial Parameter Array +p = [2.2 1.0; 2.0 0.4] # Tweaked Initial Parameter Array ps = Flux.params(p) function predict_adjoint() # Our 1-layer neural network - Array(solve(prob,Tsit5(),p=p,saveat=0.0:0.1:10.0)) + Array(solve(prob, Tsit5(), p = p, saveat = 0.0:0.1:10.0)) end -loss_adjoint() = sum(abs2,x-1 for x in predict_adjoint()) +loss_adjoint() = sum(abs2, x - 1 for x in predict_adjoint()) data = Iterators.repeated((), 100) opt = ADAM(0.1) cb = function () #callback function to observe training - display(loss_adjoint()) + display(loss_adjoint()) end predict_adjoint() diff --git a/test/sparse_adjoint.jl b/test/sparse_adjoint.jl index c473bd7e5..f3f29386e 100644 --- a/test/sparse_adjoint.jl +++ b/test/sparse_adjoint.jl @@ -3,31 +3,48 @@ using AlgebraicMultigrid: AlgebraicMultigrid using Test foop(u, p, t) = jac(u, p, t) * u -jac(u, p, t) = spdiagm(0=>p) -paramjac(u, p, t) = SparseArrays.spdiagm(0=>u) -@Zygote.adjoint foop(u, p, t) = foop(u, p, t), delta->(jac(u, p, t)' * delta, paramjac(u, p, t)' * delta, zeros(length(u))) +jac(u, p, t) = spdiagm(0 => p) +paramjac(u, p, t) = SparseArrays.spdiagm(0 => u) +Zygote.@adjoint function foop(u, p, t) + foop(u, p, t), + delta -> (jac(u, p, t)' * delta, paramjac(u, p, t)' * delta, zeros(length(u))) +end n = 2 p = collect(1.0:n) u0 = ones(n) tspan = [0.0, 1] -odef = ODEFunction(foop; jac=jac, jac_prototype=jac(u0, p, 0.0), paramjac=paramjac) -function g_helper(p; alg=Rosenbrock23(linsolve=LUFactorization())) +odef = ODEFunction(foop; jac = jac, jac_prototype = jac(u0, p, 0.0), paramjac = paramjac) +function g_helper(p; alg = Rosenbrock23(linsolve = LUFactorization())) prob = ODEProblem(odef, u0, tspan, p) - soln = Array(solve(prob, alg; u0=prob.u0, p=prob.p, abstol=1e-4, reltol=1e-4, sensealg=InterpolatingAdjoint()))[:, end] + soln = Array(solve(prob, alg; u0 = prob.u0, p = prob.p, abstol = 1e-4, reltol = 1e-4, + sensealg = InterpolatingAdjoint()))[:, end] return soln end function g(p; kwargs...) soln = g_helper(p; kwargs...) return sum(soln) end -@test isapprox(exp.(p), g_helper(p); atol=1e-3, rtol=1e-3) -@test isapprox(exp.(p), Zygote.gradient(g, p)[1]; atol=1e-3, rtol=1e-3) -@test isapprox(exp.(p), g_helper(p; alg=Rosenbrock23(linsolve=KLUFactorization())); atol=1e-3, rtol=1e-3) -@test isapprox(exp.(p), Zygote.gradient(p->g(p; alg=Rosenbrock23(linsolve=KLUFactorization())), p)[1]; atol=1e-3, rtol=1e-3) -@test isapprox(exp.(p), g_helper(p; alg=ImplicitEuler(linsolve=LUFactorization())); atol=1e-1, rtol=1e-1) -@test isapprox(exp.(p), Zygote.gradient(p->g(p; alg=ImplicitEuler(linsolve=LUFactorization())), p)[1]; atol=1e-1, rtol=1e-1) -@test isapprox(exp.(p), g_helper(p; alg=ImplicitEuler(linsolve=UMFPACKFactorization())); atol=1e-1, rtol=1e-1) -@test isapprox(exp.(p), Zygote.gradient(p->g(p; alg=ImplicitEuler(linsolve=UMFPACKFactorization())), p)[1]; atol=1e-1, rtol=1e-1) -@test isapprox(exp.(p), g_helper(p; alg=ImplicitEuler(linsolve=KrylovJL_GMRES())); atol=1e-1, rtol=1e-1) -@test isapprox(exp.(p), Zygote.gradient(p->g(p; alg=ImplicitEuler(linsolve=KrylovJL_GMRES())), p)[1]; atol=1e-1, rtol=1e-1) +@test isapprox(exp.(p), g_helper(p); atol = 1e-3, rtol = 1e-3) +@test isapprox(exp.(p), Zygote.gradient(g, p)[1]; atol = 1e-3, rtol = 1e-3) +@test isapprox(exp.(p), g_helper(p; alg = Rosenbrock23(linsolve = KLUFactorization())); + atol = 1e-3, rtol = 1e-3) +@test isapprox(exp.(p), + Zygote.gradient(p -> g(p; alg = Rosenbrock23(linsolve = KLUFactorization())), + p)[1]; atol = 1e-3, rtol = 1e-3) +@test isapprox(exp.(p), g_helper(p; alg = ImplicitEuler(linsolve = LUFactorization())); + atol = 1e-1, rtol = 1e-1) +@test isapprox(exp.(p), + Zygote.gradient(p -> g(p; alg = ImplicitEuler(linsolve = LUFactorization())), + p)[1]; atol = 1e-1, rtol = 1e-1) +@test isapprox(exp.(p), g_helper(p; alg = ImplicitEuler(linsolve = UMFPACKFactorization())); + atol = 1e-1, rtol = 1e-1) +@test isapprox(exp.(p), + Zygote.gradient(p -> g(p; + alg = ImplicitEuler(linsolve = UMFPACKFactorization())), + p)[1]; atol = 1e-1, rtol = 1e-1) +@test isapprox(exp.(p), g_helper(p; alg = ImplicitEuler(linsolve = KrylovJL_GMRES())); + atol = 1e-1, rtol = 1e-1) +@test isapprox(exp.(p), + Zygote.gradient(p -> g(p; alg = ImplicitEuler(linsolve = KrylovJL_GMRES())), + p)[1]; atol = 1e-1, rtol = 1e-1) diff --git a/test/steady_state.jl b/test/steady_state.jl index 414741003..3049de1e2 100644 --- a/test/steady_state.jl +++ b/test/steady_state.jl @@ -6,272 +6,374 @@ using Random Random.seed!(12345) @testset "Adjoint sensitivities of steady state solver" begin - function f!(du,u,p,t) - du[1] = p[1] + p[2]*u[1] - du[2] = p[3]*u[1] + p[4]*u[2] - end - - function jac!(J,u,p,t) #df/dx - J[1,1] = p[2] - J[2,1] = p[3] - J[1,2] = 0 - J[2,2] = p[4] - nothing - end - - function paramjac!(fp,u,p,t) #df/dp - fp[1,1] = 1 - fp[2,1] = 0 - fp[1,2] = u[1] - fp[2,2] = 0 - fp[1,3] = 0 - fp[2,3] = u[1] - fp[1,4] = 0 - fp[2,4] = u[2] - nothing - end - - function dg!(out,u,p,t,i) - (out.=-2.0.+u) - end - - function g(u,p,t) - sum((2.0.-u).^2)/2 + sum(p.^2)/2 - end - - u0 = zeros(2) - p = [2.0,-2.0,1.0,-4.0] - prob = SteadyStateProblem(f!,u0,p) - abstol = 1e-10 - @testset "for p" begin - println("Calculate adjoint sensitivities from Jacobians") - - sol_analytical = [-p[1]/p[2], p[1]*p[3]/(p[2]*p[4])] - - J = zeros(2,2) - fp = zeros(2,4) - gp = zeros(4) - gx = zeros(1,2) - delg_delp = copy(p) - - jac!(J,sol_analytical,p,nothing) - dg!(vec(gx),sol_analytical,p,nothing,nothing) - paramjac!(fp,sol_analytical,p,nothing) - - lambda = J' \ gx' - res_analytical = delg_delp' - lambda' * fp # = -gx*inv(J)*fp - - @info "Expected result" sol_analytical, res_analytical, delg_delp'-gx*inv(J)*fp - - - @info "Calculate adjoint sensitivities from autodiff & numerical diff" - function G(p) - tmp_prob = remake(prob,u0=convert.(eltype(p),prob.u0),p=p) - sol = solve(tmp_prob, - SSRootfind(nlsolve = (f!,u0,abstol) -> (res=NLsolve.nlsolve(f!,u0,autodiff=:forward,method=:newton,iterations=Int(1e6),ftol=1e-14);res.zero)) - ) - A = convert(Array,sol) - g(A,p,nothing) + function f!(du, u, p, t) + du[1] = p[1] + p[2] * u[1] + du[2] = p[3] * u[1] + p[4] * u[2] end - res1 = ForwardDiff.gradient(G,p) - res2 = Calculus.gradient(G,p) - #@info res1, res2, res_analytical - - @test res1 ≈ res_analytical' rtol = 1e-7 - @test res2 ≈ res_analytical' rtol = 1e-7 - @test res1 ≈ res2 rtol = 1e-7 - - - @info "Adjoint sensitivities" - - # with jac, param_jac - f1 = ODEFunction(f!;jac=jac!, paramjac=paramjac!) - prob1 = SteadyStateProblem(f1,u0,p) - sol1 = solve(prob1,DynamicSS(Rodas5(),reltol=1e-14,abstol=1e-14),reltol=1e-14,abstol=1e-14) - - res1a = adjoint_sensitivities(sol1,DynamicSS(Rodas5()),sensealg=SteadyStateAdjoint(),g,dg!) - res1b = adjoint_sensitivities(sol1,DynamicSS(Rodas5()),sensealg=SteadyStateAdjoint(),g,nothing) - res1c = adjoint_sensitivities(sol1,DynamicSS(Rodas5()),sensealg=SteadyStateAdjoint(autodiff=false),g,nothing) - res1d = adjoint_sensitivities(sol1,DynamicSS(Rodas5()),sensealg=SteadyStateAdjoint(autojacvec=TrackerVJP()),g,nothing) - res1e = adjoint_sensitivities(sol1,DynamicSS(Rodas5()),sensealg=SteadyStateAdjoint(autojacvec=ReverseDiffVJP()),g,nothing) - res1f = adjoint_sensitivities(sol1,DynamicSS(Rodas5()),sensealg=SteadyStateAdjoint(autojacvec=ZygoteVJP()),g,nothing) - res1g = adjoint_sensitivities(sol1,DynamicSS(Rodas5()),sensealg=SteadyStateAdjoint(autodiff=false,autojacvec=false),g,nothing) - res1h = adjoint_sensitivities(sol1,DynamicSS(Rodas5()),sensealg=SteadyStateAdjoint(autojacvec=EnzymeVJP()),g,nothing) - - # with jac, without param_jac - f2 = ODEFunction(f!;jac=jac!) - prob2 = SteadyStateProblem(f2,u0,p) - sol2 = solve(prob2,DynamicSS(Rodas5(),reltol=1e-14,abstol=1e-14),reltol=1e-14,abstol=1e-14) - res2a = adjoint_sensitivities(sol2,DynamicSS(Rodas5()),sensealg=SteadyStateAdjoint(),g,dg!) - res2b = adjoint_sensitivities(sol2,DynamicSS(Rodas5()),sensealg=SteadyStateAdjoint(),g,nothing) - res2c = adjoint_sensitivities(sol2,DynamicSS(Rodas5()),sensealg=SteadyStateAdjoint(autodiff=false),g,nothing) - res2d = adjoint_sensitivities(sol2,DynamicSS(Rodas5()),sensealg=SteadyStateAdjoint(autojacvec=TrackerVJP()),g,nothing) - res2e = adjoint_sensitivities(sol2,DynamicSS(Rodas5()),sensealg=SteadyStateAdjoint(autojacvec=ReverseDiffVJP()),g,nothing) - res2f = adjoint_sensitivities(sol2,DynamicSS(Rodas5()),sensealg=SteadyStateAdjoint(autojacvec=ZygoteVJP()),g,nothing) - res2g = adjoint_sensitivities(sol2,DynamicSS(Rodas5()),sensealg=SteadyStateAdjoint(autodiff=false,autojacvec=false),g,nothing) - res2h = adjoint_sensitivities(sol2,DynamicSS(Rodas5()),sensealg=SteadyStateAdjoint(autojacvec=EnzymeVJP()),g,nothing) - - # without jac, without param_jac - f3 = ODEFunction(f!) - prob3 = SteadyStateProblem(f3,u0,p) - sol3 = solve(prob3,DynamicSS(Rodas5(),reltol=1e-14,abstol=1e-14),reltol=1e-14,abstol=1e-14) - res3a = adjoint_sensitivities(sol3,DynamicSS(Rodas5()),sensealg=SteadyStateAdjoint(),g,dg!) - res3b = adjoint_sensitivities(sol3,DynamicSS(Rodas5()),sensealg=SteadyStateAdjoint(),g,nothing) - res3c = adjoint_sensitivities(sol3,DynamicSS(Rodas5()),sensealg=SteadyStateAdjoint(autodiff=false),g,nothing) - res3d = adjoint_sensitivities(sol3,DynamicSS(Rodas5()),sensealg=SteadyStateAdjoint(autojacvec=TrackerVJP()),g,nothing) - res3e = adjoint_sensitivities(sol3,DynamicSS(Rodas5()),sensealg=SteadyStateAdjoint(autojacvec=ReverseDiffVJP()),g,nothing) - res3f = adjoint_sensitivities(sol3,DynamicSS(Rodas5()),sensealg=SteadyStateAdjoint(autojacvec=ZygoteVJP()),g,nothing) - res3g = adjoint_sensitivities(sol3,DynamicSS(Rodas5()),sensealg=SteadyStateAdjoint(autodiff=false,autojacvec=false),g,nothing) - res3h = adjoint_sensitivities(sol3,DynamicSS(Rodas5()),sensealg=SteadyStateAdjoint(autojacvec=EnzymeVJP()),g,nothing) - - @test norm(res_analytical' .- res1a) < 1e-7 - @test norm(res_analytical' .- res1b) < 1e-7 - @test norm(res_analytical' .- res1c) < 1e-7 - @test norm(res_analytical' .- res1d) < 1e-7 - @test norm(res_analytical' .- res1e) < 1e-7 - @test norm(res_analytical' .- res1f) < 1e-7 - @test norm(res_analytical' .- res1g) < 1e-7 - @test norm(res_analytical' .- res1h) < 1e-7 - @test norm(res_analytical' .- res2a) < 1e-7 - @test norm(res_analytical' .- res2b) < 1e-7 - @test norm(res_analytical' .- res2c) < 1e-7 - @test norm(res_analytical' .- res2d) < 1e-7 - @test norm(res_analytical' .- res2e) < 1e-7 - @test norm(res_analytical' .- res2f) < 1e-7 - @test norm(res_analytical' .- res2g) < 1e-7 - @test norm(res_analytical' .- res2h) < 1e-7 - @test norm(res_analytical' .- res3a) < 1e-7 - @test norm(res_analytical' .- res3b) < 1e-7 - @test norm(res_analytical' .- res3c) < 1e-7 - @test norm(res_analytical' .- res3d) < 1e-7 - @test norm(res_analytical' .- res3e) < 1e-7 - @test norm(res_analytical' .- res3f) < 1e-7 - @test norm(res_analytical' .- res3g) < 1e-7 - @test norm(res_analytical' .- res3h) < 1e-7 - - @info "oop checks" - function foop(u,p,t) - dx = p[1] + p[2]*u[1] - dy = p[3]*u[1] + p[4]*u[2] - [dx,dy] - end - proboop = SteadyStateProblem(foop,u0,p) - soloop = solve(proboop,DynamicSS(Rodas5(),reltol=1e-14,abstol=1e-14),reltol=1e-14,abstol=1e-14) - - - res4a = adjoint_sensitivities(soloop,DynamicSS(Rodas5()),sensealg=SteadyStateAdjoint(),g,dg!) - res4b = adjoint_sensitivities(soloop,DynamicSS(Rodas5()),sensealg=SteadyStateAdjoint(),g,nothing) - res4c = adjoint_sensitivities(soloop,DynamicSS(Rodas5()),sensealg=SteadyStateAdjoint(autodiff=false),g,nothing) - res4d = adjoint_sensitivities(soloop,DynamicSS(Rodas5()),sensealg=SteadyStateAdjoint(autojacvec=TrackerVJP()),g,nothing) - res4e = adjoint_sensitivities(soloop,DynamicSS(Rodas5()),sensealg=SteadyStateAdjoint(autojacvec=ReverseDiffVJP()),g,nothing) - res4f = adjoint_sensitivities(soloop,DynamicSS(Rodas5()),sensealg=SteadyStateAdjoint(autojacvec=ZygoteVJP()),g,nothing) - res4g = adjoint_sensitivities(soloop,DynamicSS(Rodas5()),sensealg=SteadyStateAdjoint(autodiff=false,autojacvec=false),g,nothing) - res4h = adjoint_sensitivities(soloop,DynamicSS(Rodas5()),sensealg=SteadyStateAdjoint(autodiff=true,autojacvec=false),g,nothing) - - @test norm(res_analytical' .- res4a) < 1e-7 - @test norm(res_analytical' .- res4b) < 1e-7 - @test norm(res_analytical' .- res4c) < 1e-7 - @test norm(res_analytical' .- res4d) < 1e-7 - @test norm(res_analytical' .- res4e) < 1e-7 - @test norm(res_analytical' .- res4f) < 1e-7 - @test norm(res_analytical' .- res4g) < 1e-7 - @test norm(res_analytical' .- res4h) < 1e-7 - end - - @testset "for u0: (should be zero, steady state does not depend on initial condition)" begin - res5 = ForwardDiff.gradient(prob.u0) do u0 - tmp_prob = remake(prob,u0=u0) - sol = solve(tmp_prob, - SSRootfind(nlsolve = (f!,u0,abstol) -> (res=NLsolve.nlsolve(f!,u0,autodiff=:forward,method=:newton,iterations=Int(1e6),ftol=1e-14);res.zero) ) - ) - A = convert(Array,sol) - g(A,p,nothing) - end - @test abs(dot(res5,res5)) < 1e-7 - end -end - -using Zygote -@testset "concrete_solve derivatives steady state solver" begin - - function g1(u,p,t) - sum(u) - end - function g2(u,p,t) - sum((2.0.-u).^2)/2 - end - - u0 = zeros(2) - p = [2.0,-2.0,1.0,-4.0] - - @testset "iip" begin - function f!(du,u,p,t) - du[1] = p[1] + p[2]*u[1] - du[2] = p[3]*u[1] + p[4]*u[2] + function jac!(J, u, p, t) #df/dx + J[1, 1] = p[2] + J[2, 1] = p[3] + J[1, 2] = 0 + J[2, 2] = p[4] + nothing end - prob = SteadyStateProblem(f!,u0,p) - - sol = solve(prob,DynamicSS(Rodas5())) - res1 = adjoint_sensitivities(sol,DynamicSS(Rodas5()),sensealg=SteadyStateAdjoint(),g1,nothing) - res2 = adjoint_sensitivities(sol,DynamicSS(Rodas5()),sensealg=SteadyStateAdjoint(),g2,nothing) - - - dp1 = Zygote.gradient(p->sum(solve(prob,DynamicSS(Rodas5()),u0=u0,p=p,sensealg=SteadyStateAdjoint())),p) - dp2 = Zygote.gradient(p->sum((2.0.-solve(prob,DynamicSS(Rodas5()),u0=u0,p=p,sensealg=SteadyStateAdjoint())).^2)/2.0,p) - - dp1d = Zygote.gradient(p->sum(solve(prob,DynamicSS(Rodas5()),u0=u0,p=p)),p) - dp2d = Zygote.gradient(p->sum((2.0.-solve(prob,DynamicSS(Rodas5()),u0=u0,p=p)).^2)/2.0,p) - - @test res1 ≈ dp1[1] rtol=1e-12 - @test res2 ≈ dp2[1] rtol=1e-12 - @test res1 ≈ dp1d[1] rtol=1e-12 - @test res2 ≈ dp2d[1] rtol=1e-12 - res1 = Zygote.gradient(p->sum(Array(solve(prob,DynamicSS(Rodas5()),u0=u0,p=p,sensealg=SteadyStateAdjoint()))[1]),p) - dp1 = Zygote.gradient(p->sum(solve(prob,DynamicSS(Rodas5()),u0=u0,p=p,save_idxs=1:1,sensealg=SteadyStateAdjoint())),p) - dp2 = Zygote.gradient(p->solve(prob,DynamicSS(Rodas5()),u0=u0,p=p,save_idxs=1,sensealg=SteadyStateAdjoint())[1],p) - - dp1d = Zygote.gradient(p->sum(solve(prob,DynamicSS(Rodas5()),u0=u0,p=p,save_idxs=1:1)),p) - dp2d = Zygote.gradient(p->solve(prob,DynamicSS(Rodas5()),u0=u0,p=p,save_idxs=1)[1],p) - @test res1[1] ≈ dp1[1] rtol=1e-10 - @test res1[1] ≈ dp2[1] rtol=1e-10 - @test res1[1] ≈ dp1d[1] rtol=1e-10 - @test res1[1] ≈ dp2d[1] rtol=1e-10 - end + function paramjac!(fp, u, p, t) #df/dp + fp[1, 1] = 1 + fp[2, 1] = 0 + fp[1, 2] = u[1] + fp[2, 2] = 0 + fp[1, 3] = 0 + fp[2, 3] = u[1] + fp[1, 4] = 0 + fp[2, 4] = u[2] + nothing + end - @testset "oop" begin - function f(u,p,t) - dx = p[1] + p[2]*u[1] - dy = p[3]*u[1] + p[4]*u[2] - [dx,dy] + function dg!(out, u, p, t, i) + (out .= -2.0 .+ u) end - proboop = SteadyStateProblem(f,u0,p) + function g(u, p, t) + sum((2.0 .- u) .^ 2) / 2 + sum(p .^ 2) / 2 + end - soloop = solve(proboop,DynamicSS(Rodas5())) - res1oop = adjoint_sensitivities(soloop,DynamicSS(Rodas5()),sensealg=SteadyStateAdjoint(),g1,nothing) - res2oop = adjoint_sensitivities(soloop,DynamicSS(Rodas5()),sensealg=SteadyStateAdjoint(),g2,nothing) + u0 = zeros(2) + p = [2.0, -2.0, 1.0, -4.0] + prob = SteadyStateProblem(f!, u0, p) + abstol = 1e-10 + @testset "for p" begin + println("Calculate adjoint sensitivities from Jacobians") + + sol_analytical = [-p[1] / p[2], p[1] * p[3] / (p[2] * p[4])] + + J = zeros(2, 2) + fp = zeros(2, 4) + gp = zeros(4) + gx = zeros(1, 2) + delg_delp = copy(p) + + jac!(J, sol_analytical, p, nothing) + dg!(vec(gx), sol_analytical, p, nothing, nothing) + paramjac!(fp, sol_analytical, p, nothing) + + lambda = J' \ gx' + res_analytical = delg_delp' - lambda' * fp # = -gx*inv(J)*fp + + @info "Expected result" sol_analytical, res_analytical, + delg_delp' - gx * inv(J) * fp + + @info "Calculate adjoint sensitivities from autodiff & numerical diff" + function G(p) + tmp_prob = remake(prob, u0 = convert.(eltype(p), prob.u0), p = p) + sol = solve(tmp_prob, + SSRootfind(nlsolve = (f!, u0, abstol) -> (res = NLsolve.nlsolve(f!, + u0, + autodiff = :forward, + method = :newton, + iterations = Int(1e6), + ftol = 1e-14); + res.zero))) + A = convert(Array, sol) + g(A, p, nothing) + end + res1 = ForwardDiff.gradient(G, p) + res2 = Calculus.gradient(G, p) + #@info res1, res2, res_analytical + + @test res1≈res_analytical' rtol=1e-7 + @test res2≈res_analytical' rtol=1e-7 + @test res1≈res2 rtol=1e-7 + + @info "Adjoint sensitivities" + + # with jac, param_jac + f1 = ODEFunction(f!; jac = jac!, paramjac = paramjac!) + prob1 = SteadyStateProblem(f1, u0, p) + sol1 = solve(prob1, DynamicSS(Rodas5(), reltol = 1e-14, abstol = 1e-14), + reltol = 1e-14, abstol = 1e-14) + + res1a = adjoint_sensitivities(sol1, DynamicSS(Rodas5()), + sensealg = SteadyStateAdjoint(), g, dg!) + res1b = adjoint_sensitivities(sol1, DynamicSS(Rodas5()), + sensealg = SteadyStateAdjoint(), g, nothing) + res1c = adjoint_sensitivities(sol1, DynamicSS(Rodas5()), + sensealg = SteadyStateAdjoint(autodiff = false), g, + nothing) + res1d = adjoint_sensitivities(sol1, DynamicSS(Rodas5()), + sensealg = SteadyStateAdjoint(autojacvec = TrackerVJP()), + g, nothing) + res1e = adjoint_sensitivities(sol1, DynamicSS(Rodas5()), + sensealg = SteadyStateAdjoint(autojacvec = ReverseDiffVJP()), + g, nothing) + res1f = adjoint_sensitivities(sol1, DynamicSS(Rodas5()), + sensealg = SteadyStateAdjoint(autojacvec = ZygoteVJP()), + g, nothing) + res1g = adjoint_sensitivities(sol1, DynamicSS(Rodas5()), + sensealg = SteadyStateAdjoint(autodiff = false, + autojacvec = false), g, + nothing) + res1h = adjoint_sensitivities(sol1, DynamicSS(Rodas5()), + sensealg = SteadyStateAdjoint(autojacvec = EnzymeVJP()), + g, nothing) + + # with jac, without param_jac + f2 = ODEFunction(f!; jac = jac!) + prob2 = SteadyStateProblem(f2, u0, p) + sol2 = solve(prob2, DynamicSS(Rodas5(), reltol = 1e-14, abstol = 1e-14), + reltol = 1e-14, abstol = 1e-14) + res2a = adjoint_sensitivities(sol2, DynamicSS(Rodas5()), + sensealg = SteadyStateAdjoint(), g, dg!) + res2b = adjoint_sensitivities(sol2, DynamicSS(Rodas5()), + sensealg = SteadyStateAdjoint(), g, nothing) + res2c = adjoint_sensitivities(sol2, DynamicSS(Rodas5()), + sensealg = SteadyStateAdjoint(autodiff = false), g, + nothing) + res2d = adjoint_sensitivities(sol2, DynamicSS(Rodas5()), + sensealg = SteadyStateAdjoint(autojacvec = TrackerVJP()), + g, nothing) + res2e = adjoint_sensitivities(sol2, DynamicSS(Rodas5()), + sensealg = SteadyStateAdjoint(autojacvec = ReverseDiffVJP()), + g, nothing) + res2f = adjoint_sensitivities(sol2, DynamicSS(Rodas5()), + sensealg = SteadyStateAdjoint(autojacvec = ZygoteVJP()), + g, nothing) + res2g = adjoint_sensitivities(sol2, DynamicSS(Rodas5()), + sensealg = SteadyStateAdjoint(autodiff = false, + autojacvec = false), g, + nothing) + res2h = adjoint_sensitivities(sol2, DynamicSS(Rodas5()), + sensealg = SteadyStateAdjoint(autojacvec = EnzymeVJP()), + g, nothing) + + # without jac, without param_jac + f3 = ODEFunction(f!) + prob3 = SteadyStateProblem(f3, u0, p) + sol3 = solve(prob3, DynamicSS(Rodas5(), reltol = 1e-14, abstol = 1e-14), + reltol = 1e-14, abstol = 1e-14) + res3a = adjoint_sensitivities(sol3, DynamicSS(Rodas5()), + sensealg = SteadyStateAdjoint(), g, dg!) + res3b = adjoint_sensitivities(sol3, DynamicSS(Rodas5()), + sensealg = SteadyStateAdjoint(), g, nothing) + res3c = adjoint_sensitivities(sol3, DynamicSS(Rodas5()), + sensealg = SteadyStateAdjoint(autodiff = false), g, + nothing) + res3d = adjoint_sensitivities(sol3, DynamicSS(Rodas5()), + sensealg = SteadyStateAdjoint(autojacvec = TrackerVJP()), + g, nothing) + res3e = adjoint_sensitivities(sol3, DynamicSS(Rodas5()), + sensealg = SteadyStateAdjoint(autojacvec = ReverseDiffVJP()), + g, nothing) + res3f = adjoint_sensitivities(sol3, DynamicSS(Rodas5()), + sensealg = SteadyStateAdjoint(autojacvec = ZygoteVJP()), + g, nothing) + res3g = adjoint_sensitivities(sol3, DynamicSS(Rodas5()), + sensealg = SteadyStateAdjoint(autodiff = false, + autojacvec = false), g, + nothing) + res3h = adjoint_sensitivities(sol3, DynamicSS(Rodas5()), + sensealg = SteadyStateAdjoint(autojacvec = EnzymeVJP()), + g, nothing) + + @test norm(res_analytical' .- res1a) < 1e-7 + @test norm(res_analytical' .- res1b) < 1e-7 + @test norm(res_analytical' .- res1c) < 1e-7 + @test norm(res_analytical' .- res1d) < 1e-7 + @test norm(res_analytical' .- res1e) < 1e-7 + @test norm(res_analytical' .- res1f) < 1e-7 + @test norm(res_analytical' .- res1g) < 1e-7 + @test norm(res_analytical' .- res1h) < 1e-7 + @test norm(res_analytical' .- res2a) < 1e-7 + @test norm(res_analytical' .- res2b) < 1e-7 + @test norm(res_analytical' .- res2c) < 1e-7 + @test norm(res_analytical' .- res2d) < 1e-7 + @test norm(res_analytical' .- res2e) < 1e-7 + @test norm(res_analytical' .- res2f) < 1e-7 + @test norm(res_analytical' .- res2g) < 1e-7 + @test norm(res_analytical' .- res2h) < 1e-7 + @test norm(res_analytical' .- res3a) < 1e-7 + @test norm(res_analytical' .- res3b) < 1e-7 + @test norm(res_analytical' .- res3c) < 1e-7 + @test norm(res_analytical' .- res3d) < 1e-7 + @test norm(res_analytical' .- res3e) < 1e-7 + @test norm(res_analytical' .- res3f) < 1e-7 + @test norm(res_analytical' .- res3g) < 1e-7 + @test norm(res_analytical' .- res3h) < 1e-7 + + @info "oop checks" + function foop(u, p, t) + dx = p[1] + p[2] * u[1] + dy = p[3] * u[1] + p[4] * u[2] + [dx, dy] + end + proboop = SteadyStateProblem(foop, u0, p) + soloop = solve(proboop, DynamicSS(Rodas5(), reltol = 1e-14, abstol = 1e-14), + reltol = 1e-14, abstol = 1e-14) + + res4a = adjoint_sensitivities(soloop, DynamicSS(Rodas5()), + sensealg = SteadyStateAdjoint(), g, dg!) + res4b = adjoint_sensitivities(soloop, DynamicSS(Rodas5()), + sensealg = SteadyStateAdjoint(), g, nothing) + res4c = adjoint_sensitivities(soloop, DynamicSS(Rodas5()), + sensealg = SteadyStateAdjoint(autodiff = false), g, + nothing) + res4d = adjoint_sensitivities(soloop, DynamicSS(Rodas5()), + sensealg = SteadyStateAdjoint(autojacvec = TrackerVJP()), + g, nothing) + res4e = adjoint_sensitivities(soloop, DynamicSS(Rodas5()), + sensealg = SteadyStateAdjoint(autojacvec = ReverseDiffVJP()), + g, nothing) + res4f = adjoint_sensitivities(soloop, DynamicSS(Rodas5()), + sensealg = SteadyStateAdjoint(autojacvec = ZygoteVJP()), + g, nothing) + res4g = adjoint_sensitivities(soloop, DynamicSS(Rodas5()), + sensealg = SteadyStateAdjoint(autodiff = false, + autojacvec = false), g, + nothing) + res4h = adjoint_sensitivities(soloop, DynamicSS(Rodas5()), + sensealg = SteadyStateAdjoint(autodiff = true, + autojacvec = false), g, + nothing) + + @test norm(res_analytical' .- res4a) < 1e-7 + @test norm(res_analytical' .- res4b) < 1e-7 + @test norm(res_analytical' .- res4c) < 1e-7 + @test norm(res_analytical' .- res4d) < 1e-7 + @test norm(res_analytical' .- res4e) < 1e-7 + @test norm(res_analytical' .- res4f) < 1e-7 + @test norm(res_analytical' .- res4g) < 1e-7 + @test norm(res_analytical' .- res4h) < 1e-7 + end + @testset "for u0: (should be zero, steady state does not depend on initial condition)" begin + res5 = ForwardDiff.gradient(prob.u0) do u0 + tmp_prob = remake(prob, u0 = u0) + sol = solve(tmp_prob, + SSRootfind(nlsolve = (f!, u0, abstol) -> (res = NLsolve.nlsolve(f!, + u0, + autodiff = :forward, + method = :newton, + iterations = Int(1e6), + ftol = 1e-14); + res.zero))) + A = convert(Array, sol) + g(A, p, nothing) + end + @test abs(dot(res5, res5)) < 1e-7 + end +end - dp1oop = Zygote.gradient(p->sum(solve(proboop,DynamicSS(Rodas5()),u0=u0,p=p,sensealg=SteadyStateAdjoint())),p) - dp2oop = Zygote.gradient(p->sum((2.0.-solve(proboop,DynamicSS(Rodas5()),u0=u0,p=p,sensealg=SteadyStateAdjoint())).^2)/2.0,p) - dp1oopd = Zygote.gradient(p->sum(solve(proboop,DynamicSS(Rodas5()),u0=u0,p=p)),p) - dp2oopd = Zygote.gradient(p->sum((2.0.-solve(proboop,DynamicSS(Rodas5()),u0=u0,p=p)).^2)/2.0,p) +using Zygote +@testset "concrete_solve derivatives steady state solver" begin + function g1(u, p, t) + sum(u) + end - @test res1oop ≈ dp1oop[1] rtol=1e-12 - @test res2oop ≈ dp2oop[1] rtol=1e-12 - @test res1oop ≈ dp1oopd[1] rtol=1e-8 - @test res2oop ≈ dp2oopd[1] rtol=1e-8 + function g2(u, p, t) + sum((2.0 .- u) .^ 2) / 2 + end + u0 = zeros(2) + p = [2.0, -2.0, 1.0, -4.0] + + @testset "iip" begin + function f!(du, u, p, t) + du[1] = p[1] + p[2] * u[1] + du[2] = p[3] * u[1] + p[4] * u[2] + end + prob = SteadyStateProblem(f!, u0, p) + + sol = solve(prob, DynamicSS(Rodas5())) + res1 = adjoint_sensitivities(sol, DynamicSS(Rodas5()), + sensealg = SteadyStateAdjoint(), g1, nothing) + res2 = adjoint_sensitivities(sol, DynamicSS(Rodas5()), + sensealg = SteadyStateAdjoint(), g2, nothing) + + dp1 = Zygote.gradient(p -> sum(solve(prob, DynamicSS(Rodas5()), u0 = u0, p = p, + sensealg = SteadyStateAdjoint())), p) + dp2 = Zygote.gradient(p -> sum((2.0 .- + solve(prob, DynamicSS(Rodas5()), u0 = u0, p = p, + sensealg = SteadyStateAdjoint())) .^ 2) / 2.0, + p) + + dp1d = Zygote.gradient(p -> sum(solve(prob, DynamicSS(Rodas5()), u0 = u0, p = p)), + p) + dp2d = Zygote.gradient(p -> sum((2.0 .- + solve(prob, DynamicSS(Rodas5()), u0 = u0, p = p)) .^ + 2) / 2.0, p) + + @test res1≈dp1[1] rtol=1e-12 + @test res2≈dp2[1] rtol=1e-12 + @test res1≈dp1d[1] rtol=1e-12 + @test res2≈dp2d[1] rtol=1e-12 + + res1 = Zygote.gradient(p -> sum(Array(solve(prob, DynamicSS(Rodas5()), u0 = u0, + p = p, sensealg = SteadyStateAdjoint()))[1]), + p) + dp1 = Zygote.gradient(p -> sum(solve(prob, DynamicSS(Rodas5()), u0 = u0, p = p, + save_idxs = 1:1, + sensealg = SteadyStateAdjoint())), p) + dp2 = Zygote.gradient(p -> solve(prob, DynamicSS(Rodas5()), u0 = u0, p = p, + save_idxs = 1, sensealg = SteadyStateAdjoint())[1], + p) + + dp1d = Zygote.gradient(p -> sum(solve(prob, DynamicSS(Rodas5()), u0 = u0, p = p, + save_idxs = 1:1)), p) + dp2d = Zygote.gradient(p -> solve(prob, DynamicSS(Rodas5()), u0 = u0, p = p, + save_idxs = 1)[1], p) + @test res1[1]≈dp1[1] rtol=1e-10 + @test res1[1]≈dp2[1] rtol=1e-10 + @test res1[1]≈dp1d[1] rtol=1e-10 + @test res1[1]≈dp2d[1] rtol=1e-10 + end - res1oop = Zygote.gradient(p->sum(Array(solve(proboop,DynamicSS(Rodas5()),u0=u0,p=p,sensealg=SteadyStateAdjoint()))[1]),p) - dp1oop = Zygote.gradient(p->sum(solve(proboop,DynamicSS(Rodas5()),u0=u0,p=p,save_idxs=1:1,sensealg=SteadyStateAdjoint())),p) - dp2oop = Zygote.gradient(p->solve(proboop,DynamicSS(Rodas5()),u0=u0,p=p,save_idxs=1,sensealg=SteadyStateAdjoint())[1],p) - dp1oopd = Zygote.gradient(p->sum(solve(proboop,DynamicSS(Rodas5()),u0=u0,p=p,save_idxs=1:1)),p) - dp2oopd = Zygote.gradient(p->solve(proboop,DynamicSS(Rodas5()),u0=u0,p=p,save_idxs=1)[1],p) - @test res1oop[1] ≈ dp1oop[1] rtol=1e-10 - @test res1oop[1] ≈ dp2oop[1] rtol=1e-10 - end + @testset "oop" begin + function f(u, p, t) + dx = p[1] + p[2] * u[1] + dy = p[3] * u[1] + p[4] * u[2] + [dx, dy] + end + proboop = SteadyStateProblem(f, u0, p) + + soloop = solve(proboop, DynamicSS(Rodas5())) + res1oop = adjoint_sensitivities(soloop, DynamicSS(Rodas5()), + sensealg = SteadyStateAdjoint(), g1, nothing) + res2oop = adjoint_sensitivities(soloop, DynamicSS(Rodas5()), + sensealg = SteadyStateAdjoint(), g2, nothing) + + dp1oop = Zygote.gradient(p -> sum(solve(proboop, DynamicSS(Rodas5()), u0 = u0, + p = p, sensealg = SteadyStateAdjoint())), p) + dp2oop = Zygote.gradient(p -> sum((2.0 .- + solve(proboop, DynamicSS(Rodas5()), u0 = u0, + p = p, sensealg = SteadyStateAdjoint())) .^ + 2) / 2.0, p) + dp1oopd = Zygote.gradient(p -> sum(solve(proboop, DynamicSS(Rodas5()), u0 = u0, + p = p)), p) + dp2oopd = Zygote.gradient(p -> sum((2.0 .- + solve(proboop, DynamicSS(Rodas5()), u0 = u0, + p = p)) .^ 2) / 2.0, p) + + @test res1oop≈dp1oop[1] rtol=1e-12 + @test res2oop≈dp2oop[1] rtol=1e-12 + @test res1oop≈dp1oopd[1] rtol=1e-8 + @test res2oop≈dp2oopd[1] rtol=1e-8 + + res1oop = Zygote.gradient(p -> sum(Array(solve(proboop, DynamicSS(Rodas5()), + u0 = u0, p = p, + sensealg = SteadyStateAdjoint()))[1]), + p) + dp1oop = Zygote.gradient(p -> sum(solve(proboop, DynamicSS(Rodas5()), u0 = u0, + p = p, save_idxs = 1:1, + sensealg = SteadyStateAdjoint())), p) + dp2oop = Zygote.gradient(p -> solve(proboop, DynamicSS(Rodas5()), u0 = u0, p = p, + save_idxs = 1, sensealg = SteadyStateAdjoint())[1], + p) + dp1oopd = Zygote.gradient(p -> sum(solve(proboop, DynamicSS(Rodas5()), u0 = u0, + p = p, save_idxs = 1:1)), p) + dp2oopd = Zygote.gradient(p -> solve(proboop, DynamicSS(Rodas5()), u0 = u0, p = p, + save_idxs = 1)[1], p) + @test res1oop[1]≈dp1oop[1] rtol=1e-10 + @test res1oop[1]≈dp2oop[1] rtol=1e-10 + end end - diff --git a/test/stiff_adjoints.jl b/test/stiff_adjoints.jl index 3b4948926..52c85ed67 100644 --- a/test/stiff_adjoints.jl +++ b/test/stiff_adjoints.jl @@ -1,205 +1,213 @@ -using Zygote, SciMLSensitivity -println("Starting tests") -using OrdinaryDiffEq, ForwardDiff, Test - -function lotka_volterra(u, p, t) - x, y = u - α, β, δ, γ = p - [α * x - β * x * y,-δ * y + γ * x * y] -end - -function lotka_volterra(du, u, p, t) - x, y = u - α, β, δ, γ = p - du[1] = dx = α * x - β * x * y - du[2] = dy = -δ * y + γ * x * y -end - -u0 = [1.0,1.0]; -tspan = (0.0,10.0); -p0 = [1.5,1.0,3.0,1.0]; -prob0 = ODEProblem(lotka_volterra,u0,tspan,p0); -# Solve the ODE and collect solutions at fixed intervals -target_data = solve(prob0,RadauIIA5(), saveat = 0:0.5:10.0); - -loss_function = function(p) - prob = remake(prob0;u0=convert.(eltype(p),prob0.u0),p=p) - prediction = solve(prob, RadauIIA5(); saveat = 0.0:0.5:10.0,abstol=1e-10,reltol=1e-10) - - tmpdata=prediction[[1,2],:]; - tdata=target_data[[1,2],:]; - - # Calculate squared error - return sum(abs2,tmpdata-tdata) -end -p=[1.5,1.2,1.4,1.6]; -fdgrad = ForwardDiff.gradient(loss_function,p) -rdgrad = Zygote.gradient(loss_function,p)[1] - -@test fdgrad ≈ rdgrad rtol=1e-5 - -loss_function = function(p) - prob = remake(prob0;u0=convert.(eltype(p),prob0.u0),p=p) - prediction = solve(prob, TRBDF2(); saveat = 0.0:0.5:10.0,abstol=1e-10,reltol=1e-10) - - tmpdata=prediction[[1,2],:]; - tdata=target_data[[1,2],:]; - - # Calculate squared error - return sum(abs2,tmpdata-tdata) -end - -rdgrad = Zygote.gradient(loss_function,p)[1] -@test fdgrad ≈ rdgrad rtol=1e-3 - -loss_function = function(p) - prob = remake(prob0;u0=convert.(eltype(p),prob0.u0),p=p) - prediction = solve(prob, Rosenbrock23(); saveat = 0.0:0.5:10.0,abstol=1e-8,reltol=1e-8) - - tmpdata=prediction[[1,2],:]; - tdata=target_data[[1,2],:]; - - # Calculate squared error - return sum(abs2,tmpdata-tdata) -end - -rdgrad = Zygote.gradient(loss_function,p)[1] -@test fdgrad ≈ rdgrad rtol=1e-3 - -loss_function = function(p) - prob = remake(prob0;u0=convert.(eltype(p),prob0.u0),p=p) - prediction = solve(prob, Rodas5(); saveat = 0.0:0.5:10.0,abstol=1e-8,reltol=1e-8) - - tmpdata=prediction[[1,2],:]; - tdata=target_data[[1,2],:]; - - # Calculate squared error - return sum(abs2,tmpdata-tdata) -end - -rdgrad = Zygote.gradient(loss_function,p)[1] -@test fdgrad ≈ rdgrad rtol=1e-3 - -### OOP - -prob0_oop = ODEProblem{false}(lotka_volterra,u0,tspan,p0); -# Solve the ODE and collect solutions at fixed intervals -target_data = solve(prob0,RadauIIA5(), saveat = 0:0.5:10.0); - -loss_function = function(p) - prob = remake(prob0_oop;u0=convert.(eltype(p),prob0.u0),p=p) - prediction = solve(prob, RadauIIA5(); saveat = 0.0:0.5:10.0,abstol=1e-10,reltol=1e-10) - - tmpdata=prediction[[1,2],:]; - tdata=target_data[[1,2],:]; - - # Calculate squared error - return sum(abs2,tmpdata-tdata) -end -p=[1.5,1.2,1.4,1.6]; - -fdgrad = ForwardDiff.gradient(loss_function,p) -rdgrad = Zygote.gradient(loss_function,p)[1] - -@test fdgrad ≈ rdgrad rtol=1e-4 - -loss_function = function(p) - prob = remake(prob0_oop;u0=convert.(eltype(p),prob0.u0),p=p) - prediction = solve(prob, TRBDF2(); saveat = 0.0:0.5:10.0,abstol=1e-10,reltol=1e-10) - - tmpdata=prediction[[1,2],:]; - tdata=target_data[[1,2],:]; - - # Calculate squared error - return sum(abs2,tmpdata-tdata) -end - -rdgrad = Zygote.gradient(loss_function,p)[1] -@test fdgrad ≈ rdgrad rtol=1e-3 - -loss_function = function(p) - prob = remake(prob0_oop;u0=convert.(eltype(p),prob0.u0),p=p) - prediction = solve(prob, Rosenbrock23(); saveat = 0.0:0.5:10.0,abstol=1e-8,reltol=1e-8) - - tmpdata=prediction[[1,2],:]; - tdata=target_data[[1,2],:]; - - # Calculate squared error - return sum(abs2,tmpdata-tdata) -end - -rdgrad = Zygote.gradient(loss_function,p)[1] -@test fdgrad ≈ rdgrad rtol=1e-4 - -loss_function = function(p) - prob = remake(prob0_oop;u0=convert.(eltype(p),prob0.u0),p=p) - prediction = solve(prob, Rodas5(); saveat = 0.0:0.5:10.0,abstol=1e-12,reltol=1e-12) - - tmpdata=prediction[[1,2],:]; - tdata=target_data[[1,2],:]; - - # Calculate squared error - return sum(abs2,tmpdata-tdata) -end - -rdgrad = Zygote.gradient(loss_function,p)[1] -@test fdgrad ≈ rdgrad rtol=1e-3 - -# all implicit solvers -solvers = [ - # SDIRK Methods (ok) - ImplicitEuler(), - TRBDF2(), - KenCarp4(), - # Fully-Implicit Runge-Kutta Methods (FIRK) - RadauIIA5(), - # Fully-Implicit Runge-Kutta Methods (FIRK) - #PDIRK44(), - # Rosenbrock Methods - Rodas3(), - Rodas4(), - Rodas5(), - # Rosenbrock-W Methods - Rosenbrock23(), - ROS34PW3(), - # Stabilized Explicit Methods (ok) - ROCK2(), - ROCK4(), - RKC(), - # SERK2v2(), not defined? - ESERK5()]; - -p = rand(3) - -function dudt(u,p,t) - u .* p -end - -for solver in solvers - function loss(p) - prob = ODEProblem(dudt, [3.0, 2.0, 1.0], (0.0, 1.0), p) - sol = solve(prob, solver, dt=0.01, saveat=0.1, abstol=1e-5, reltol=1e-5) - sum(abs2, Array(sol)) - end - - println(DiffEqBase.parameterless_type(solver)) - loss(p) - dp = Zygote.gradient(loss, p)[1] - - function loss(p, sensealg) - prob = ODEProblem(dudt, [3.0, 2.0, 1.0], (0.0, 1.0), p) - sol = solve(prob, solver, dt=0.01, saveat=0.1, sensealg=sensealg, abstol=1e-5, reltol=1e-5) - sum(abs2, Array(sol)) - end - - dp1 = Zygote.gradient(p -> loss(p, InterpolatingAdjoint()), p)[1] - @test dp ≈ dp1 rtol = 1e-2 - dp1 = Zygote.gradient(p -> loss(p, BacksolveAdjoint()), p)[1] - @test dp ≈ dp1 rtol = 1e-2 - dp1 = Zygote.gradient(p -> loss(p, QuadratureAdjoint()), p)[1] - @test dp ≈ dp1 rtol = 1e-2 - dp1 = Zygote.gradient(p -> loss(p, ForwardDiffSensitivity()), p)[1] - @test dp ≈ dp1 rtol = 1e-2 - dp1 = @test_broken Zygote.gradient(p -> loss(p, ReverseDiffAdjoint()), p)[1] - @test_broken dp ≈ dp1 rtol = 1e-2 -end +using Zygote, SciMLSensitivity +println("Starting tests") +using OrdinaryDiffEq, ForwardDiff, Test + +function lotka_volterra(u, p, t) + x, y = u + α, β, δ, γ = p + [α * x - β * x * y, -δ * y + γ * x * y] +end + +function lotka_volterra(du, u, p, t) + x, y = u + α, β, δ, γ = p + du[1] = dx = α * x - β * x * y + du[2] = dy = -δ * y + γ * x * y +end + +u0 = [1.0, 1.0]; +tspan = (0.0, 10.0); +p0 = [1.5, 1.0, 3.0, 1.0]; +prob0 = ODEProblem(lotka_volterra, u0, tspan, p0); +# Solve the ODE and collect solutions at fixed intervals +target_data = solve(prob0, RadauIIA5(), saveat = 0:0.5:10.0); + +loss_function = function (p) + prob = remake(prob0; u0 = convert.(eltype(p), prob0.u0), p = p) + prediction = solve(prob, RadauIIA5(); saveat = 0.0:0.5:10.0, abstol = 1e-10, + reltol = 1e-10) + + tmpdata = prediction[[1, 2], :] + tdata = target_data[[1, 2], :] + + # Calculate squared error + return sum(abs2, tmpdata - tdata) +end +p = [1.5, 1.2, 1.4, 1.6]; +fdgrad = ForwardDiff.gradient(loss_function, p) +rdgrad = Zygote.gradient(loss_function, p)[1] + +@test fdgrad≈rdgrad rtol=1e-5 + +loss_function = function (p) + prob = remake(prob0; u0 = convert.(eltype(p), prob0.u0), p = p) + prediction = solve(prob, TRBDF2(); saveat = 0.0:0.5:10.0, abstol = 1e-10, + reltol = 1e-10) + + tmpdata = prediction[[1, 2], :] + tdata = target_data[[1, 2], :] + + # Calculate squared error + return sum(abs2, tmpdata - tdata) +end + +rdgrad = Zygote.gradient(loss_function, p)[1] +@test fdgrad≈rdgrad rtol=1e-3 + +loss_function = function (p) + prob = remake(prob0; u0 = convert.(eltype(p), prob0.u0), p = p) + prediction = solve(prob, Rosenbrock23(); saveat = 0.0:0.5:10.0, abstol = 1e-8, + reltol = 1e-8) + + tmpdata = prediction[[1, 2], :] + tdata = target_data[[1, 2], :] + + # Calculate squared error + return sum(abs2, tmpdata - tdata) +end + +rdgrad = Zygote.gradient(loss_function, p)[1] +@test fdgrad≈rdgrad rtol=1e-3 + +loss_function = function (p) + prob = remake(prob0; u0 = convert.(eltype(p), prob0.u0), p = p) + prediction = solve(prob, Rodas5(); saveat = 0.0:0.5:10.0, abstol = 1e-8, reltol = 1e-8) + + tmpdata = prediction[[1, 2], :] + tdata = target_data[[1, 2], :] + + # Calculate squared error + return sum(abs2, tmpdata - tdata) +end + +rdgrad = Zygote.gradient(loss_function, p)[1] +@test fdgrad≈rdgrad rtol=1e-3 + +### OOP + +prob0_oop = ODEProblem{false}(lotka_volterra, u0, tspan, p0); +# Solve the ODE and collect solutions at fixed intervals +target_data = solve(prob0, RadauIIA5(), saveat = 0:0.5:10.0); + +loss_function = function (p) + prob = remake(prob0_oop; u0 = convert.(eltype(p), prob0.u0), p = p) + prediction = solve(prob, RadauIIA5(); saveat = 0.0:0.5:10.0, abstol = 1e-10, + reltol = 1e-10) + + tmpdata = prediction[[1, 2], :] + tdata = target_data[[1, 2], :] + + # Calculate squared error + return sum(abs2, tmpdata - tdata) +end +p = [1.5, 1.2, 1.4, 1.6]; + +fdgrad = ForwardDiff.gradient(loss_function, p) +rdgrad = Zygote.gradient(loss_function, p)[1] + +@test fdgrad≈rdgrad rtol=1e-4 + +loss_function = function (p) + prob = remake(prob0_oop; u0 = convert.(eltype(p), prob0.u0), p = p) + prediction = solve(prob, TRBDF2(); saveat = 0.0:0.5:10.0, abstol = 1e-10, + reltol = 1e-10) + + tmpdata = prediction[[1, 2], :] + tdata = target_data[[1, 2], :] + + # Calculate squared error + return sum(abs2, tmpdata - tdata) +end + +rdgrad = Zygote.gradient(loss_function, p)[1] +@test fdgrad≈rdgrad rtol=1e-3 + +loss_function = function (p) + prob = remake(prob0_oop; u0 = convert.(eltype(p), prob0.u0), p = p) + prediction = solve(prob, Rosenbrock23(); saveat = 0.0:0.5:10.0, abstol = 1e-8, + reltol = 1e-8) + + tmpdata = prediction[[1, 2], :] + tdata = target_data[[1, 2], :] + + # Calculate squared error + return sum(abs2, tmpdata - tdata) +end + +rdgrad = Zygote.gradient(loss_function, p)[1] +@test fdgrad≈rdgrad rtol=1e-4 + +loss_function = function (p) + prob = remake(prob0_oop; u0 = convert.(eltype(p), prob0.u0), p = p) + prediction = solve(prob, Rodas5(); saveat = 0.0:0.5:10.0, abstol = 1e-12, + reltol = 1e-12) + + tmpdata = prediction[[1, 2], :] + tdata = target_data[[1, 2], :] + + # Calculate squared error + return sum(abs2, tmpdata - tdata) +end + +rdgrad = Zygote.gradient(loss_function, p)[1] +@test fdgrad≈rdgrad rtol=1e-3 + +# all implicit solvers +solvers = [ + # SDIRK Methods (ok) + ImplicitEuler(), + TRBDF2(), + KenCarp4(), + # Fully-Implicit Runge-Kutta Methods (FIRK) + RadauIIA5(), + # Fully-Implicit Runge-Kutta Methods (FIRK) + #PDIRK44(), + # Rosenbrock Methods + Rodas3(), + Rodas4(), + Rodas5(), + # Rosenbrock-W Methods + Rosenbrock23(), + ROS34PW3(), + # Stabilized Explicit Methods (ok) + ROCK2(), + ROCK4(), + RKC(), + # SERK2v2(), not defined? + ESERK5()]; + +p = rand(3) + +function dudt(u, p, t) + u .* p +end + +for solver in solvers + function loss(p) + prob = ODEProblem(dudt, [3.0, 2.0, 1.0], (0.0, 1.0), p) + sol = solve(prob, solver, dt = 0.01, saveat = 0.1, abstol = 1e-5, reltol = 1e-5) + sum(abs2, Array(sol)) + end + + println(DiffEqBase.parameterless_type(solver)) + loss(p) + dp = Zygote.gradient(loss, p)[1] + + function loss(p, sensealg) + prob = ODEProblem(dudt, [3.0, 2.0, 1.0], (0.0, 1.0), p) + sol = solve(prob, solver, dt = 0.01, saveat = 0.1, sensealg = sensealg, + abstol = 1e-5, reltol = 1e-5) + sum(abs2, Array(sol)) + end + + dp1 = Zygote.gradient(p -> loss(p, InterpolatingAdjoint()), p)[1] + @test dp≈dp1 rtol=1e-2 + dp1 = Zygote.gradient(p -> loss(p, BacksolveAdjoint()), p)[1] + @test dp≈dp1 rtol=1e-2 + dp1 = Zygote.gradient(p -> loss(p, QuadratureAdjoint()), p)[1] + @test dp≈dp1 rtol=1e-2 + dp1 = Zygote.gradient(p -> loss(p, ForwardDiffSensitivity()), p)[1] + @test dp≈dp1 rtol=1e-2 + dp1 = @test_broken Zygote.gradient(p -> loss(p, ReverseDiffAdjoint()), p)[1] + @test_broken dp≈dp1 rtol=1e-2 +end diff --git a/test/time_type_mixing.jl b/test/time_type_mixing.jl index 4f08215bf..df7a13727 100644 --- a/test/time_type_mixing.jl +++ b/test/time_type_mixing.jl @@ -1,6 +1,6 @@ using OrdinaryDiffEq, Zygote, SciMLSensitivity, Test -p_model = [1f0] +p_model = [1.0f0] u0 = Float32.([0.0]) @@ -8,18 +8,18 @@ function dudt(du, u, p, t) du[1] = p[1] end -prob = ODEProblem(dudt,u0,(0f0,99.9f0)) +prob = ODEProblem(dudt, u0, (0.0f0, 99.9f0)) function predict_neuralode(p) - _prob = remake(prob,p=p) - Array(solve(_prob,Tsit5(), saveat=0.1)) + _prob = remake(prob, p = p) + Array(solve(_prob, Tsit5(), saveat = 0.1)) end -loss(p) = sum(abs2,predict_neuralode(p))/length(p) +loss(p) = sum(abs2, predict_neuralode(p)) / length(p) p_model_ini = copy(p_model) -@test !iszero(Zygote.gradient(loss,p_model_ini)[1]) +@test !iszero(Zygote.gradient(loss, p_model_ini)[1]) ## https://github.com/SciML/SciMLSensitivity.jl/issues/675 @@ -29,15 +29,15 @@ p = [-0.1 2.0; -2.0 -0.1] datasize = 30 # Number of data points tspan = (0.0f0, 1.5f0) # Time range # tsteps = range(tspan[1], tspan[2], length = datasize) # Split time range into equal steps for each data point -tsteps = (rand(datasize) .* (tspan[2] - tspan[1]) .+ tspan[1]) |> sort +tsteps = (rand(datasize) .* (tspan[2] - tspan[1]) .+ tspan[1]) |> sort -function f(du,u,p,t) - du .= p*u +function f(du, u, p, t) + du .= p * u end function loss(p) - prob = ODEProblem(f,u0,tspan,p) - sol = solve(prob,Tsit5(),saveat=tsteps,sensealg=InterpolatingAdjoint()) + prob = ODEProblem(f, u0, tspan, p) + sol = solve(prob, Tsit5(), saveat = tsteps, sensealg = InterpolatingAdjoint()) sum(sol) end