From ec8603fb278fa1fd52a9b5b7010ae943ec25fb5e Mon Sep 17 00:00:00 2001 From: Alban Gossard Date: Thu, 17 Jul 2025 15:36:30 +0200 Subject: [PATCH 1/3] add repack/canonicalize in vec_pjac! to support SciMLStructs --- src/gauss_adjoint.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/gauss_adjoint.jl b/src/gauss_adjoint.jl index 076809986..744a47d6e 100644 --- a/src/gauss_adjoint.jl +++ b/src/gauss_adjoint.jl @@ -494,13 +494,14 @@ function vec_pjac!(out, λ, y, t, S::GaussIntegrand) Enzyme.remake_zero!(tmp3) Enzyme.remake_zero!(out) + dp = isscimlstructure(p) ? repack(out) : out if SciMLBase.isinplace(sol.prob.f) Enzyme.remake_zero!(tmp6) Enzyme.autodiff( Enzyme.Reverse, Enzyme.Duplicated(pf, tmp6), Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4), - Enzyme.Const(y), Enzyme.Duplicated(p, out), Enzyme.Const(t)) + Enzyme.Const(y), Enzyme.Duplicated(p, dp), Enzyme.Const(t)) else function g(du, u, p, t) du .= f(u, p, t) @@ -510,7 +511,10 @@ function vec_pjac!(out, λ, y, t, S::GaussIntegrand) Enzyme.autodiff( Enzyme.Reverse, Enzyme.Duplicated(g, tmp6), Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4), - Enzyme.Const(y), Enzyme.Duplicated(p, out), Enzyme.Const(t)) + Enzyme.Const(y), Enzyme.Duplicated(p, dp), Enzyme.Const(t)) + end + if isscimlstructure(p) + out .+= canonicalize(Tunable(), dp)[1] end elseif sensealg.autojacvec isa MooncakeVJP _, _, p_grad = mooncake_run_ad(paramjac_config, y, p, t, λ) From eac3d53a95453a6be2d188d80defec431f8f0854 Mon Sep 17 00:00:00 2001 From: Alban Gossard Date: Thu, 17 Jul 2025 22:16:38 +0200 Subject: [PATCH 2/3] fix typo in vec_pjac! of GaussAdjoint for SciMLStructures --- src/gauss_adjoint.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gauss_adjoint.jl b/src/gauss_adjoint.jl index 744a47d6e..3723dd1b5 100644 --- a/src/gauss_adjoint.jl +++ b/src/gauss_adjoint.jl @@ -514,7 +514,7 @@ function vec_pjac!(out, λ, y, t, S::GaussIntegrand) Enzyme.Const(y), Enzyme.Duplicated(p, dp), Enzyme.Const(t)) end if isscimlstructure(p) - out .+= canonicalize(Tunable(), dp)[1] + out .= canonicalize(Tunable(), dp)[1] end elseif sensealg.autojacvec isa MooncakeVJP _, _, p_grad = mooncake_run_ad(paramjac_config, y, p, t, λ) From 424495d9e47c4a28f3edceeeedd5a7ae8b2a66d6 Mon Sep 17 00:00:00 2001 From: Alban Gossard Date: Thu, 17 Jul 2025 22:47:43 +0200 Subject: [PATCH 3/3] add test of GaussAdjoint with EnzymeVJP and SciMLStructs --- test/scimlstructures_interface.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/scimlstructures_interface.jl b/test/scimlstructures_interface.jl index 09a61609e..a93a259cd 100644 --- a/test/scimlstructures_interface.jl +++ b/test/scimlstructures_interface.jl @@ -158,4 +158,5 @@ end run_diff(initialize()) @test !iszero(Zygote.gradient(run_diff, initialize(), GaussAdjoint())[1].ps) -@test !iszero(Zygote.gradient(run_diff, initialize(), GaussAdjoint(autojacvec=false))[1].ps) \ No newline at end of file +@test !iszero(Zygote.gradient(run_diff, initialize(), GaussAdjoint(autojacvec=false))[1].ps) +@test !iszero(Zygote.gradient(run_diff, initialize(), GaussAdjoint(autojacvec=EnzymeVJP()))[1].ps) \ No newline at end of file