From 90137a70423d9e218fb89c4efe31e5dce09e7506 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Fri, 5 Jul 2024 19:11:00 +0000 Subject: [PATCH 1/2] chore: add scimlstructures path to default VJP for SDEProblem --- src/concrete_solve.jl | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index 2cf062fdb..ca693c47e 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -79,12 +79,21 @@ function automatic_sensealg_choice( prob::Union{SciMLBase.AbstractODEProblem, SciMLBase.AbstractSDEProblem}, u0, p, verbose, repack) + + if p === nothing || p isa SciMLBase.NullParameters + tunables, repack = p, identity + elseif isscimlstructure(p) + tunables, repack, _ = canonicalize(Tunable(), p) + else + throw(SciMLStructuresCompatibilityError()) + end + default_sensealg = if p !== DiffEqBase.NullParameters() && !(eltype(u0) <: ForwardDiff.Dual) && !(eltype(p) <: ForwardDiff.Dual) && !(eltype(u0) <: Complex) && !(eltype(p) <: Complex) && - length(u0) + length(p) <= 100 + length(u0) + length(tunables) <= 100 ForwardDiffSensitivity() elseif u0 isa GPUArraysCore.AbstractGPUArray || !DiffEqBase.isinplace(prob) # only Zygote is GPU compatible and fast @@ -124,7 +133,7 @@ function automatic_sensealg_choice( ReverseDiff.gradient((u) -> sum(prob.f(u, p, prob.tspan[1])), u0) else ReverseDiff.gradient( - (u, _p) -> sum(prob.f(u, repack(_p), prob.tspan[1])), u0, p) + (u, _p) -> sum(prob.f(u, repack(_p), prob.tspan[1])), u0, tunables) end ReverseDiffVJP() catch e @@ -151,8 +160,8 @@ function automatic_sensealg_choice( end tmp1 = back(λ) else - _dy, back = Tracker.forward(y, p) do u, p - vec(f(u, p, t)) + _dy, back = Tracker.forward(y, tunables) do u, tunables + vec(f(u, repack(tunables), t)) end tmp1, tmp2 = back(λ) end From 94e212a420547ee948ac56908cadac56a76d232d Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Fri, 5 Jul 2024 19:13:22 +0000 Subject: [PATCH 2/2] chore: format --- src/concrete_solve.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index ca693c47e..e11dac844 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -79,7 +79,6 @@ function automatic_sensealg_choice( prob::Union{SciMLBase.AbstractODEProblem, SciMLBase.AbstractSDEProblem}, u0, p, verbose, repack) - if p === nothing || p isa SciMLBase.NullParameters tunables, repack = p, identity elseif isscimlstructure(p) @@ -161,7 +160,7 @@ function automatic_sensealg_choice( tmp1 = back(λ) else _dy, back = Tracker.forward(y, tunables) do u, tunables - vec(f(u, repack(tunables), t)) + vec(f(u, repack(tunables), t)) end tmp1, tmp2 = back(λ) end