Skip to content

Commit 93908be

Browse files
DhairyaLGandhiChrisRackauckas
authored andcommitted
test: add a simple test for AD
1 parent 60235e3 commit 93908be

File tree

3 files changed

+37
-0
lines changed

3 files changed

+37
-0
lines changed

test/extensions/Project.toml

+5
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
11
[deps]
22
BifurcationKit = "0f109fa4-8a5d-4b75-95aa-f515264e7665"
33
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
4+
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
5+
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
6+
SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226"
7+
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
8+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

test/extensions/ad.jl

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
using ModelingToolkit
2+
using ModelingToolkit: t_nounits as t, D_nounits as D
3+
using Zygote
4+
using SymbolicIndexingInterface
5+
using SciMLStructures
6+
using OrdinaryDiffEq
7+
using SciMLSensitivity
8+
9+
@variables x(t)[1:3] y(t)
10+
@parameters p[1:3, 1:3] q
11+
eqs = [
12+
D(x) ~ p * x
13+
D(y) ~ sum(p) + q * y
14+
]
15+
u0 = [x => zeros(3),
16+
y => 1.]
17+
ps = [p => zeros(3, 3),
18+
q => 1.]
19+
tspan = (0., 10.)
20+
@mtkbuild sys = ODESystem(eqs, t)
21+
prob = ODEProblem(sys, u0, tspan, ps)
22+
sol = solve(prob, Tsit5())
23+
24+
mtkparams = parameter_values(prob)
25+
new_p = rand(10)
26+
gs = gradient(new_p) do new_p
27+
new_params = SciMLStructures.replace(SciMLStructures.Tunable(), mtkparams, new_p)
28+
new_prob = remake(prob, p = new_params)
29+
new_sol = solve(new_prob, Tsit5())
30+
sum(new_sol)
31+
end

test/runtests.jl

+1
Original file line numberDiff line numberDiff line change
@@ -104,5 +104,6 @@ end
104104
if GROUP == "All" || GROUP == "Extensions"
105105
activate_extensions_env()
106106
@safetestset "BifurcationKit Extension Test" include("extensions/bifurcationkit.jl")
107+
@safetestset "Auto Differentiation Test" include("extensions/ad.jl")
107108
end
108109
end

0 commit comments

Comments
 (0)