diff --git a/Project.toml b/Project.toml index b01239d091..97d9fc5729 100644 --- a/Project.toml +++ b/Project.toml @@ -57,10 +57,12 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [weakdeps] BifurcationKit = "0f109fa4-8a5d-4b75-95aa-f515264e7665" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6" [extensions] MTKBifurcationKitExt = "BifurcationKit" +MTKChainRulesCoreExt = "ChainRulesCore" MTKDeepDiffsExt = "DeepDiffs" [compat] diff --git a/ext/MTKChainRulesCoreExt.jl b/ext/MTKChainRulesCoreExt.jl new file mode 100644 index 0000000000..6350b34b89 --- /dev/null +++ b/ext/MTKChainRulesCoreExt.jl @@ -0,0 +1,14 @@ +module MTKChainRulesCoreExt + +import ModelingToolkit as MTK +import ChainRulesCore +import ChainRulesCore: NoTangent + +function ChainRulesCore.rrule(::Type{MTK.MTKParameters}, tunables, args...) + function mtp_pullback(dt) + (NoTangent(), dt.tunable[1:length(tunables)], ntuple(_ -> NoTangent(), length(args))...) + end + MTK.MTKParameters(tunables, args...), mtp_pullback +end + +end diff --git a/test/extensions/Project.toml b/test/extensions/Project.toml index 6b3cfb57c2..611ba7e1cf 100644 --- a/test/extensions/Project.toml +++ b/test/extensions/Project.toml @@ -1,3 +1,8 @@ [deps] BifurcationKit = "0f109fa4-8a5d-4b75-95aa-f515264e7665" ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" +OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" +SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" +SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226" +SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/test/extensions/ad.jl b/test/extensions/ad.jl new file mode 100644 index 0000000000..2628a6ecd9 --- /dev/null +++ b/test/extensions/ad.jl @@ -0,0 +1,31 @@ +using ModelingToolkit +using ModelingToolkit: t_nounits as t, D_nounits as D +using Zygote +using SymbolicIndexingInterface +using SciMLStructures +using OrdinaryDiffEq +using SciMLSensitivity + +@variables x(t)[1:3] y(t) +@parameters p[1:3, 1:3] q +eqs = [ + D(x) ~ p * x + D(y) ~ sum(p) + q * y +] +u0 = [x => zeros(3), + y => 1.] +ps = [p => zeros(3, 3), + q => 1.] +tspan = (0., 10.) +@mtkbuild sys = ODESystem(eqs, t) +prob = ODEProblem(sys, u0, tspan, ps) +sol = solve(prob, Tsit5()) + +mtkparams = parameter_values(prob) +new_p = rand(10) +gs = gradient(new_p) do new_p + new_params = SciMLStructures.replace(SciMLStructures.Tunable(), mtkparams, new_p) + new_prob = remake(prob, p = new_params) + new_sol = solve(new_prob, Tsit5()) + sum(new_sol) +end diff --git a/test/runtests.jl b/test/runtests.jl index 611081bb20..69003ee511 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -104,5 +104,6 @@ end if GROUP == "All" || GROUP == "Extensions" activate_extensions_env() @safetestset "BifurcationKit Extension Test" include("extensions/bifurcationkit.jl") + @safetestset "Auto Differentiation Test" include("extensions/ad.jl") end end