diff --git a/Project.toml b/Project.toml index 43c52dea1a..c379843e25 100644 --- a/Project.toml +++ b/Project.toml @@ -68,14 +68,13 @@ SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226" SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7" SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804" Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" StaticArrayInterface = "0d7ed370-da01-4f52-bd93-41d350b8b718" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77" [compat] -ADTypes = "0.2, 1" +ADTypes = "1.13" Adapt = "3.0, 4" ArrayInterface = "7" DataStructures = "0.18" @@ -138,7 +137,6 @@ SciMLOperators = "0.3" SciMLStructures = "1" SimpleNonlinearSolve = "1, 2" SimpleUnPack = "1" -SparseDiffTools = "2" Static = "0.8, 1" StaticArrayInterface = "1.2" StaticArrays = "1.0" @@ -153,6 +151,7 @@ ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def" DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d" ElasticArrays = "fdbdab4c-e67f-52f5-8c3f-e7b388dad3d4" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" IncompleteLU = "40713840-3770-5561-ab4c-a76e7d0d7895" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" @@ -167,6 +166,8 @@ RecursiveFactorization = "f2c3362d-daeb-58d1-803e-2bc74f2840b4" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5" +SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" @@ -174,4 +175,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [targets] -test = ["Calculus", "ComponentArrays", "Symbolics", "AlgebraicMultigrid", "IncompleteLU", "DiffEqCallbacks", "DiffEqDevTools", "ODEProblemLibrary", "ElasticArrays", "InteractiveUtils", "ParameterizedFunctions", "PoissonRandom", "Printf", "Random", "ReverseDiff", "SafeTestsets", "SparseArrays", "Statistics", "StructArrays", "Test", "Unitful", "ModelingToolkit", "Pkg", "NLsolve", "RecursiveFactorization"] +test = ["Calculus", "ComponentArrays", "Symbolics", "AlgebraicMultigrid", "IncompleteLU", "DiffEqCallbacks", "DiffEqDevTools", "ODEProblemLibrary", "ElasticArrays", "InteractiveUtils", "ParameterizedFunctions", "PoissonRandom", "Printf", "Random", "ReverseDiff", "SafeTestsets", "SparseArrays", "Statistics", "StructArrays", "Test", "Unitful", "ModelingToolkit", "Pkg", "NLsolve", "RecursiveFactorization", "Enzyme", "SparseConnectivityTracer", "SparseMatrixColorings"] diff --git a/lib/OrdinaryDiffEqBDF/Project.toml b/lib/OrdinaryDiffEqBDF/Project.toml index 222253884a..91db0b15ef 100644 --- a/lib/OrdinaryDiffEqBDF/Project.toml +++ b/lib/OrdinaryDiffEqBDF/Project.toml @@ -50,7 +50,9 @@ julia = "1.10" [extras] DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" ODEProblemLibrary = "fdc4e326-1af4-4b90-96e7-779fcce2daa5" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" @@ -58,4 +60,4 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["DiffEqDevTools", "ForwardDiff", "Random", "SafeTestsets", "Test", "ODEProblemLibrary", "StaticArrays"] +test = ["DiffEqDevTools", "ForwardDiff", "Random", "SafeTestsets", "Test", "ODEProblemLibrary", "StaticArrays", "Enzyme", "LinearSolve"] diff --git a/lib/OrdinaryDiffEqBDF/test/bdf_convergence_tests.jl b/lib/OrdinaryDiffEqBDF/test/bdf_convergence_tests.jl index 8f7115649c..0dd80ea062 100644 --- a/lib/OrdinaryDiffEqBDF/test/bdf_convergence_tests.jl +++ b/lib/OrdinaryDiffEqBDF/test/bdf_convergence_tests.jl @@ -1,5 +1,5 @@ # This definitely needs cleaning -using OrdinaryDiffEqBDF, ODEProblemLibrary, DiffEqDevTools +using OrdinaryDiffEqBDF, ODEProblemLibrary, DiffEqDevTools, ADTypes, Enzyme, LinearSolve using OrdinaryDiffEqNonlinearSolve: NLFunctional, NLAnderson, NonlinearSolveAlg using Test, Random Random.seed!(100) @@ -39,6 +39,27 @@ dts3 = 1 .// 2 .^ (12:-1:7) @test sim.𝒪est[:l2]≈1 atol=testTol @test sim.𝒪est[:l∞]≈1 atol=testTol + sim = test_convergence(dts, prob, QNDF1(autodiff = AutoFiniteDiff())) + @test sim.𝒪est[:final]≈1 atol=testTol + @test sim.𝒪est[:l2]≈1 atol=testTol + @test sim.𝒪est[:l∞]≈1 atol=testTol + + sim = test_convergence(dts, + prob, + QNDF1(autodiff = AutoEnzyme(mode = set_runtime_activity(Enzyme.Forward), + function_annotation = Enzyme.Const))) + @test sim.𝒪est[:final]≈1 atol=testTol + @test sim.𝒪est[:l2]≈1 atol=testTol + @test sim.𝒪est[:l∞]≈1 atol=testTol + + sim = test_convergence(dts, + prob, + QNDF1(autodiff = AutoEnzyme(mode = set_runtime_activity(Enzyme.Forward), + function_annotation = Enzyme.Const), linsolve = LinearSolve.KrylovJL())) + @test sim.𝒪est[:final]≈1 atol=testTol + @test sim.𝒪est[:l2]≈1 atol=testTol + @test sim.𝒪est[:l∞]≈1 atol=testTol + sim = test_convergence(dts3, prob, QNDF2()) @test sim.𝒪est[:final]≈2 atol=testTol @test sim.𝒪est[:l2]≈2 atol=testTol diff --git a/lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl b/lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl index add39834a1..3a262c420d 100644 --- a/lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl +++ b/lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl @@ -72,7 +72,7 @@ import DiffEqBase: calculate_residuals, import Polyester using MacroTools, Adapt -import ADTypes: AutoFiniteDiff, AutoForwardDiff, AbstractADType +import ADTypes: AutoFiniteDiff, AutoForwardDiff, AbstractADType, AutoSparse import Accessors: @reset using SciMLStructures: canonicalize, Tunable, isscimlstructure diff --git a/lib/OrdinaryDiffEqCore/src/alg_utils.jl b/lib/OrdinaryDiffEqCore/src/alg_utils.jl index a7ef37a7bf..227882b8bd 100644 --- a/lib/OrdinaryDiffEqCore/src/alg_utils.jl +++ b/lib/OrdinaryDiffEqCore/src/alg_utils.jl @@ -176,6 +176,7 @@ _get_fwd_chunksize(::Type{<:AutoForwardDiff{CS}}) where {CS} = Val(CS) _get_fwd_chunksize_int(::Type{<:AutoForwardDiff{CS}}) where {CS} = CS _get_fwd_chunksize(AD) = Val(0) _get_fwd_chunksize_int(AD) = 0 +_get_fwd_chunksize_int(::AutoForwardDiff{CS}) where {CS} = CS _get_fwd_tag(::AutoForwardDiff{CS, T}) where {CS, T} = T _get_fdtype(::AutoFiniteDiff{T1}) where {T1} = T1 diff --git a/lib/OrdinaryDiffEqCore/src/algorithms.jl b/lib/OrdinaryDiffEqCore/src/algorithms.jl index 1b2a535467..fb56ff608b 100644 --- a/lib/OrdinaryDiffEqCore/src/algorithms.jl +++ b/lib/OrdinaryDiffEqCore/src/algorithms.jl @@ -59,9 +59,16 @@ function DiffEqBase.remake( }, DAEAlgorithm{CS, AD, FDT, ST, CJ}}; kwargs...) where {CS, AD, FDT, ST, CJ} + + if haskey(kwargs, :autodiff) && kwargs[:autodiff] isa AutoForwardDiff + chunk_size = _get_fwd_chunksize(kwargs[:autodiff]) + else + chunk_size = Val{CS}() + end + T = SciMLBase.remaker_of(thing) T(; SciMLBase.struct_as_namedtuple(thing)..., - chunk_size = Val{CS}(), autodiff = thing.autodiff, standardtag = Val{ST}(), + chunk_size = chunk_size, autodiff = thing.autodiff, standardtag = Val{ST}(), concrete_jac = CJ === nothing ? CJ : Val{CJ}(), kwargs...) end diff --git a/lib/OrdinaryDiffEqCore/src/misc_utils.jl b/lib/OrdinaryDiffEqCore/src/misc_utils.jl index ff81df99b4..7194d4638b 100644 --- a/lib/OrdinaryDiffEqCore/src/misc_utils.jl +++ b/lib/OrdinaryDiffEqCore/src/misc_utils.jl @@ -158,10 +158,11 @@ end function _process_AD_choice( ad_alg::AutoForwardDiff{CS}, ::Val{CS2}, ::Val{FD}) where {CS, CS2, FD} # Non-default `chunk_size` - if CS2 != 0 + if (CS2 != 0) && (isnothing(CS) || (CS2 !== CS)) @warn "The `chunk_size` keyword is deprecated. Please use an `ADType` specifier. For now defaulting to using `AutoForwardDiff` with `chunksize=$(CS2)`." return _bool_to_ADType(Val{true}(), Val{CS2}(), Val{FD}()), Val{CS2}(), Val{FD}() end + _CS = CS === nothing ? 0 : CS return ad_alg, Val{_CS}(), Val{FD}() end @@ -186,3 +187,12 @@ function _process_AD_choice( end return ad_alg, Val{CS}(), ad_alg.fdtype end + +function _process_AD_choice(ad_alg::AutoSparse, cs2::Val{CS2}, fd::Val{FD}) where {CS2, FD} + _, cs, fd = _process_AD_choice(ad_alg.dense_ad, cs2, fd) + ad_alg, cs, fd +end + +function _process_AD_choice(ad_alg, cs2, fd) + ad_alg, cs2, fd +end diff --git a/lib/OrdinaryDiffEqDifferentiation/Project.toml b/lib/OrdinaryDiffEqDifferentiation/Project.toml index f32ee0d56a..8963e335a3 100644 --- a/lib/OrdinaryDiffEqDifferentiation/Project.toml +++ b/lib/OrdinaryDiffEqDifferentiation/Project.toml @@ -6,7 +6,10 @@ version = "1.4.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" +ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" +ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" @@ -15,16 +18,18 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804" +SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35" StaticArrayInterface = "0d7ed370-da01-4f52-bd93-41d350b8b718" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [compat] -ADTypes = "1.11" +ADTypes = "1.14" ArrayInterface = "7" DiffEqBase = "6" DiffEqDevTools = "2.44.4" +DifferentiationInterface = "0.6.48" FastBroadcast = "0.3" FiniteDiff = "2" ForwardDiff = "0.10" @@ -36,7 +41,6 @@ Random = "<0.0.1, 1" SafeTestsets = "0.1.0" SciMLBase = "2" SparseArrays = "1" -SparseDiffTools = "2" StaticArrayInterface = "1" StaticArrays = "1" Test = "<0.0.1, 1" diff --git a/lib/OrdinaryDiffEqDifferentiation/src/OrdinaryDiffEqDifferentiation.jl b/lib/OrdinaryDiffEqDifferentiation/src/OrdinaryDiffEqDifferentiation.jl index 25a696c778..bb6cb4d00a 100644 --- a/lib/OrdinaryDiffEqDifferentiation/src/OrdinaryDiffEqDifferentiation.jl +++ b/lib/OrdinaryDiffEqDifferentiation/src/OrdinaryDiffEqDifferentiation.jl @@ -1,10 +1,7 @@ module OrdinaryDiffEqDifferentiation -import ADTypes: AutoFiniteDiff, AutoForwardDiff, AbstractADType - -import SparseDiffTools: SparseDiffTools, matrix_colors, forwarddiff_color_jacobian!, - forwarddiff_color_jacobian, ForwardColorJacCache, - default_chunk_size, getsize, JacVec +import ADTypes +import ADTypes: AutoFiniteDiff, AutoForwardDiff, AbstractADType, AutoSparse import ForwardDiff, FiniteDiff import ForwardDiff.Dual @@ -16,7 +13,7 @@ using DiffEqBase import LinearAlgebra import LinearAlgebra: Diagonal, I, UniformScaling, diagind, mul!, lmul!, axpby!, opnorm, lu import LinearAlgebra: LowerTriangular, UpperTriangular -import SparseArrays: SparseMatrixCSC, AbstractSparseMatrix, nonzeros +import SparseArrays: SparseMatrixCSC, AbstractSparseMatrix, nonzeros, sparse import ArrayInterface import StaticArrayInterface @@ -27,7 +24,9 @@ import StaticArrays: SArray, MVector, SVector, @SVector, StaticArray, MMatrix, S using DiffEqBase: TimeGradientWrapper, UJacobianWrapper, TimeDerivativeWrapper, UDerivativeWrapper -using SciMLBase: AbstractSciMLOperator, constructorof +using SciMLBase: AbstractSciMLOperator, constructorof, @set +using SciMLOperators +import SparseMatrixColorings import OrdinaryDiffEqCore using OrdinaryDiffEqCore: OrdinaryDiffEqAlgorithm, OrdinaryDiffEqAdaptiveImplicitAlgorithm, DAEAlgorithm, @@ -44,11 +43,16 @@ using OrdinaryDiffEqCore: OrdinaryDiffEqAlgorithm, OrdinaryDiffEqAdaptiveImplici FastConvergence, Convergence, SlowConvergence, VerySlowConvergence, Divergence, NLStatus, MethodType, constvalue -import OrdinaryDiffEqCore: get_chunksize, resize_J_W!, resize_nlsolver!, alg_autodiff, - _get_fwd_tag +import OrdinaryDiffEqCore: get_chunksize, resize_J_W!, resize_nlsolver!, alg_autodiff, _get_fwd_tag + +using ConstructionBase + +import DifferentiationInterface as DI using FastBroadcast: @.. +using ConcreteStructs: @concrete + @static if isdefined(DiffEqBase, :OrdinaryDiffEqTag) import DiffEqBase: OrdinaryDiffEqTag else @@ -59,5 +63,6 @@ include("alg_utils.jl") include("linsolve_utils.jl") include("derivative_utils.jl") include("derivative_wrappers.jl") +include("operators.jl") end diff --git a/lib/OrdinaryDiffEqDifferentiation/src/alg_utils.jl b/lib/OrdinaryDiffEqDifferentiation/src/alg_utils.jl index 354bbcb463..f9ccbbac61 100644 --- a/lib/OrdinaryDiffEqDifferentiation/src/alg_utils.jl +++ b/lib/OrdinaryDiffEqDifferentiation/src/alg_utils.jl @@ -41,58 +41,113 @@ end function DiffEqBase.prepare_alg( alg::Union{ - OrdinaryDiffEqAdaptiveImplicitAlgorithm{0, AD, + OrdinaryDiffEqAdaptiveImplicitAlgorithm{CS, AD, FDT}, - OrdinaryDiffEqImplicitAlgorithm{0, AD, FDT}, - DAEAlgorithm{0, AD, FDT}, - OrdinaryDiffEqExponentialAlgorithm{0, AD, FDT}}, + OrdinaryDiffEqImplicitAlgorithm{CS, AD, FDT}, + DAEAlgorithm{CS, AD, FDT}, + OrdinaryDiffEqExponentialAlgorithm{CS, AD, FDT}}, u0::AbstractArray{T}, - p, prob) where {AD, FDT, T} + p, prob) where {CS, AD, FDT, T} - # If not using autodiff or norecompile mode or very large bitsize (like a dual number u0 already) - # don't use a large chunksize as it will either error or not be beneficial - # If prob.f.f is a FunctionWrappersWrappers from ODEFunction, need to set chunksize to 1 - if alg_autodiff(alg) isa AutoForwardDiff && ((prob.f isa ODEFunction && - prob.f.f isa FunctionWrappersWrappers.FunctionWrappersWrapper) || - (isbitstype(T) && sizeof(T) > 24)) - return remake( - alg, autodiff = AutoForwardDiff(chunksize = 1, tag = alg_autodiff(alg).tag)) + prepped_AD = prepare_ADType(alg_autodiff(alg), prob, u0, p, standardtag(alg)) + + sparse_prepped_AD = prepare_user_sparsity(prepped_AD, prob) + + # if u0 is a StaticArray or eltype is Complex etc. don't use sparsity + if (((typeof(u0) <: StaticArray) || (eltype(u0) <: Complex) || (!(prob.f isa DAEFunction) && prob.f.mass_matrix isa MatrixOperator)) && sparse_prepped_AD isa AutoSparse) + @warn "Input type or problem definition is incompatible with sparse automatic differentiation. Switching to using dense automatic differentiation." + autodiff = ADTypes.dense_ad(sparse_prepped_AD) + else + autodiff = sparse_prepped_AD + end + + return remake(alg, autodiff = autodiff) +end + +function prepare_ADType(autodiff_alg::AutoSparse, prob, u0, p, standardtag) + SciMLBase.@set autodiff_alg.dense_ad = prepare_ADType(ADTypes.dense_ad(autodiff_alg), prob, u0, p, standardtag) +end + +function prepare_ADType(autodiff_alg::AutoForwardDiff, prob, u0, p, standardtag) + tag = if standardtag + ForwardDiff.Tag(OrdinaryDiffEqTag(), eltype(u0)) + else + nothing end + T = eltype(u0) + + fwd_cs = OrdinaryDiffEqCore._get_fwd_chunksize_int(autodiff_alg) + + cs = fwd_cs == 0 ? nothing : fwd_cs + + if ((prob.f isa ODEFunction && + prob.f.f isa FunctionWrappersWrappers.FunctionWrappersWrapper) || + (isbitstype(T) && sizeof(T) > 24)) && (cs == 0 || isnothing(cs)) + return AutoForwardDiff{1}(tag) + else + return AutoForwardDiff{cs}(tag) + end +end + +function prepare_ADType(alg::AutoFiniteDiff, prob, u0, p, standardtag) # If the autodiff alg is AutoFiniteDiff, prob.f.f isa FunctionWrappersWrapper, # and fdtype is complex, fdtype needs to change to something not complex - if alg_autodiff(alg) isa AutoFiniteDiff - if alg_difftype(alg) == Val{:complex} && (prob.f isa ODEFunction && - prob.f.f isa FunctionWrappersWrappers.FunctionWrappersWrapper) - @warn "AutoFiniteDiff fdtype complex is not compatible with this function" - return remake(alg, autodiff = AutoFiniteDiff(fdtype = Val{:forward}())) - end - return alg + if alg.fdtype == Val{:complex}() && (prob.f isa ODEFunction && prob.f.f isa FunctionWrappersWrappers.FunctionWrappersWrapper) + @warn "AutoFiniteDiff fdtype complex is not compatible with this function" + return AutoFiniteDiff(fdtype = Val{:forward}()) end + return alg +end + +function prepare_user_sparsity(ad_alg, prob) + jac_prototype = prob.f.jac_prototype + sparsity = prob.f.sparsity + + if !isnothing(sparsity) && !(ad_alg isa AutoSparse) + if sparsity isa SparseMatrixCSC + if prob.f.mass_matrix isa UniformScaling + idxs = diagind(sparsity) + @. @view(sparsity[idxs]) = 1 + + if !isnothing(jac_prototype) + @. @view(jac_prototype[idxs]) = 1 + end + else + idxs = findall(!iszero, prob.f.mass_matrix) + for idx in idxs + sparsity[idx] = prob.f.mass_matrix[idx] + end - L = StaticArrayInterface.known_length(typeof(u0)) - if L === nothing # dynamic sized - # If chunksize is zero, pick chunksize right at the start of solve and - # then do function barrier to infer the full solve - x = if prob.f.colorvec === nothing - length(u0) - else - maximum(prob.f.colorvec) + if !isnothing(jac_prototype) + for idx in idxs + jac_prototype[idx] = f.mass_matrix[idx] + end + end + end end - cs = ForwardDiff.pickchunksize(x) - return remake(alg, - autodiff = AutoForwardDiff( - chunksize = cs)) - else # statically sized - cs = pick_static_chunksize(Val{L}()) - cs = SciMLBase._unwrap_val(cs) - return remake( - alg, autodiff = AutoForwardDiff(chunksize = cs)) + # KnownJacobianSparsityDetector needs an AbstractMatrix + sparsity = sparsity isa MatrixOperator ? sparsity.A : sparsity + + color_alg = DiffEqBase.has_colorvec(prob.f) ? + SparseMatrixColorings.ConstantColoringAlgorithm( + sparsity, prob.f.colorvec) : SparseMatrixColorings.GreedyColoringAlgorithm() + + sparsity_detector = ADTypes.KnownJacobianSparsityDetector(sparsity) + + return AutoSparse( + ad_alg, sparsity_detector = sparsity_detector, coloring_algorithm = color_alg) + else + return ad_alg end end +function prepare_ADType(alg::AbstractADType, prob, u0,p,standardtag) + return alg +end + @generated function pick_static_chunksize(::Val{chunksize}) where {chunksize} x = ForwardDiff.pickchunksize(chunksize) :(Val{$x}()) diff --git a/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl b/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl index 21b23c60fa..ec564b944f 100644 --- a/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl +++ b/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl @@ -39,7 +39,33 @@ function calc_tderivative!(integrator, cache, dtd1, repeat_step) else tf.uprev = uprev tf.p = p - derivative!(dT, tf, t, du2, integrator, cache.grad_config) + alg = unwrap_alg(integrator, true) + + autodiff_alg = ADTypes.dense_ad(alg_autodiff(alg)) + + # Convert t to eltype(dT) if using ForwardDiff, to make FunctionWrappers work + t = autodiff_alg isa AutoForwardDiff ? convert(eltype(dT),t) : t + + grad_config_tup = cache.grad_config + + if autodiff_alg isa AutoFiniteDiff + grad_config = diffdir(integrator) > 0 ? grad_config_tup[1] : grad_config_tup[2] + else + grad_config = grad_config_tup[1] + end + + if integrator.iter == 1 + try + DI.derivative!( + tf, linsolve_tmp, dT, grad_config, autodiff_alg, t) + catch e + throw(FirstAutodiffTgradError(e)) + end + else + DI.derivative!(tf, linsolve_tmp, dT, grad_config, autodiff_alg, t) + end + + OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) end end @@ -48,7 +74,7 @@ function calc_tderivative!(integrator, cache, dtd1, repeat_step) end function calc_tderivative(integrator, cache) - @unpack t, dt, uprev, u, f, p = integrator + @unpack t, dt, uprev, u, f, p, alg = integrator # Time derivative if DiffEqBase.has_tgrad(f) @@ -57,7 +83,24 @@ function calc_tderivative(integrator, cache) tf = cache.tf tf.u = uprev tf.p = p - dT = derivative(tf, t, integrator) + + autodiff_alg = ADTypes.dense_ad(alg_autodiff(alg)) + + if alg_autodiff isa AutoFiniteDiff + autodiff_alg = SciMLBase.@set autodiff_alg.dir = diffdir(integrator) + end + + if integrator.iter == 1 + try + dT = DI.derivative(tf, autodiff_alg, t) + catch e + throw(FirstAutodiffTgradError(e)) + end + else + dT = DI.derivative(tf, autodiff_alg, t) + end + + OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) end dT end @@ -97,7 +140,6 @@ function calc_J(integrator, cache, next_step::Bool = false) uf.f = nlsolve_f(f, alg) uf.p = p uf.t = t - J = jacobian(uf, uprev, integrator) end @@ -613,6 +655,7 @@ end W = J else W = J - mass_matrix * inv(dtgamma) + if !isa(W, Number) W = DiffEqBase.default_factorize(W) end @@ -677,7 +720,7 @@ function update_W!(nlsolver::AbstractNLSolver, nothing end -function build_J_W(alg, u, uprev, p, t, dt, f::F, ::Type{uEltypeNoUnits}, +function build_J_W(alg, u, uprev, p, t, dt, f::F, jac_config, ::Type{uEltypeNoUnits}, ::Val{IIP}) where {IIP, uEltypeNoUnits, F} # TODO - make J, W AbstractSciMLOperators (lazily defined with scimlops functionality) # TODO - if jvp given, make it SciMLOperators.FunctionOperator @@ -708,12 +751,10 @@ function build_J_W(alg, u, uprev, p, t, dt, f::F, ::Type{uEltypeNoUnits}, # If the user has chosen GMRES but no sparse Jacobian, assume that the dense # Jacobian is a bad idea and create a fully matrix-free solver. This can # be overridden with concrete_jac. + jacvec = JVPCache(f, copy(u), u, p, t, autodiff = alg_autodiff(alg)) - _f = islin ? (isode ? f.f : f.f1.f) : f - jacvec = JacVec((du, u, p, t) -> _f(du, u, p, t), copy(u), p, t; - autodiff = alg_autodiff(alg), tag = OrdinaryDiffEqTag()) J = jacvec - W = WOperator{IIP}(f.mass_matrix, dt, J, u, jacvec) + W = WOperator{IIP}(f.mass_matrix, promote(t, dt)[2], J, u, jacvec) elseif alg.linsolve !== nothing && !LinearSolve.needs_concrete_A(alg.linsolve) || concrete_jac(alg) !== nothing && concrete_jac(alg) # The linear solver does not need a concrete Jacobian, but the user has @@ -721,21 +762,29 @@ function build_J_W(alg, u, uprev, p, t, dt, f::F, ::Type{uEltypeNoUnits}, # Thus setup JacVec and a concrete J, using sparsity when possible _f = islin ? (isode ? f.f : f.f1.f) : f J = if f.jac_prototype === nothing - ArrayInterface.undefmatrix(u) + if alg_autodiff(alg) isa AutoSparse + if isnothing(f.sparsity) + !isnothing(jac_config) ? + convert.( + eltype(u), SparseMatrixColorings.sparsity_pattern(jac_config[1])) : + spzeros(eltype(u), length(u), length(u)) + elseif eltype(f.sparsity) == Bool + convert.(eltype(u), f.sparsity) + else + f.sparsity + end + else + ArrayInterface.zeromatrix(u) + end else deepcopy(f.jac_prototype) end W = if J isa StaticMatrix StaticWOperator(J, false) else - __f = if IIP - (du, u, p, t) -> _f(du, u, p, t) - else - (u, p, t) -> _f(u, p, t) - end - jacvec = JacVec(__f, copy(u), p, t; - autodiff = alg_autodiff(alg), tag = OrdinaryDiffEqTag()) - WOperator{IIP}(f.mass_matrix, dt, J, u, jacvec) + jacvec = JVPCache(f, copy(u), u, p, t, autodiff = alg_autodiff(alg)) + + WOperator{IIP}(f.mass_matrix, promote(t, dt)[2], J, u, jacvec) end else J = if !IIP && DiffEqBase.has_jac(f) @@ -745,7 +794,19 @@ function build_J_W(alg, u, uprev, p, t, dt, f::F, ::Type{uEltypeNoUnits}, f.jac(uprev, p, t) end elseif f.jac_prototype === nothing - ArrayInterface.undefmatrix(u) + if alg_autodiff(alg) isa AutoSparse + + if isnothing(f.sparsity) + !isnothing(jac_config) ? convert.(eltype(u), SparseMatrixColorings.sparsity_pattern(jac_config[1])) : + spzeros(eltype(u), length(u), length(u)) + elseif eltype(f.sparsity) == Bool + convert.(eltype(u), f.sparsity) + else + f.sparsity + end + else + ArrayInterface.zeromatrix(u) + end else deepcopy(f.jac_prototype) end @@ -817,13 +878,13 @@ function resize_J_W!(cache, integrator, i) islin = f isa Union{ODEFunction, SplitFunction} && islinear(nf.f) if !islin if cache.J isa AbstractSciMLOperator - resize!(cache.J, i) + resize_JVPCache!(cache.J, f, cache.du1, integrator.u, alg_autodiff(integrator.alg)) elseif f.jac_prototype !== nothing J = similar(f.jac_prototype, i, i) J = MatrixOperator(J; update_func! = f.jac) end if cache.W.jacvec isa AbstractSciMLOperator - resize!(cache.W.jacvec, i) + resize_JVPCache!(cache.W.jacvec, f, cache.du1, integrator.u, alg_autodiff(integrator.alg)) end cache.W = WOperator{DiffEqBase.isinplace(integrator.sol.prob)}(f.mass_matrix, integrator.dt, @@ -841,3 +902,6 @@ function resize_J_W!(cache, integrator, i) nothing end + +getsize(::Val{N}) where {N} = N +getsize(N::Integer) = N \ No newline at end of file diff --git a/lib/OrdinaryDiffEqDifferentiation/src/derivative_wrappers.jl b/lib/OrdinaryDiffEqDifferentiation/src/derivative_wrappers.jl index 3ef9cf3c4f..32d3ee3935 100644 --- a/lib/OrdinaryDiffEqDifferentiation/src/derivative_wrappers.jl +++ b/lib/OrdinaryDiffEqDifferentiation/src/derivative_wrappers.jl @@ -4,7 +4,7 @@ const FIRST_AUTODIFF_TGRAD_MESSAGE = """ with automatic differentiation. Methods to fix this include: 1. Turn off automatic differentiation (e.g. Rosenbrock23() becomes - Rosenbrock23(autodiff=false)). More details can be found at + Rosenbrock23(autodiff=AutoFiniteDiff())). More details can be found at https://docs.sciml.ai/DiffEqDocs/stable/features/performance_overloads/ 2. Improving the compatibility of `f` with ForwardDiff.jl automatic differentiation (using tools like PreallocationTools.jl). More details @@ -46,7 +46,7 @@ const FIRST_AUTODIFF_JAC_MESSAGE = """ with automatic differentiation. Methods to fix this include: 1. Turn off automatic differentiation (e.g. Rosenbrock23() becomes - Rosenbrock23(autodiff=false)). More details can befound at + Rosenbrock23(autodiff = AutoFiniteDiff())). More details can befound at https://docs.sciml.ai/DiffEqDocs/stable/features/performance_overloads/ 2. Improving the compatibility of `f` with ForwardDiff.jl automatic differentiation (using tools like PreallocationTools.jl). More details @@ -75,198 +75,171 @@ function Base.showerror(io::IO, e::FirstAutodiffJacError) Base.showerror(io, e.e) end -function derivative!(df::AbstractArray{<:Number}, f, - x::Union{Number, AbstractArray{<:Number}}, fx::AbstractArray{<:Number}, - integrator, grad_config) +function jacobian(f, x::AbstractArray{<:Number}, integrator) alg = unwrap_alg(integrator, true) - tmp = length(x) # We calculate derivative for all elements in gradient - autodiff_alg = alg_autodiff(alg) - if autodiff_alg isa AutoForwardDiff - T = if standardtag(alg) - typeof(ForwardDiff.Tag(OrdinaryDiffEqTag(), eltype(df))) - else - typeof(ForwardDiff.Tag(f, eltype(df))) - end + + # Update stats.nf - xdual = Dual{T, eltype(df), 1}(convert(eltype(df), x), - ForwardDiff.Partials((one(eltype(df)),))) + dense = ADTypes.dense_ad(alg_autodiff(alg)) - if integrator.iter == 1 - try - f(grad_config, xdual) - catch e - throw(FirstAutodiffTgradError(e)) - end - else - f(grad_config, xdual) - end + if dense isa AutoForwardDiff + sparsity, colorvec = sparsity_colorvec(integrator.f, x) + maxcolor = maximum(colorvec) + chunk_size = (get_chunksize(alg) == Val(0) || get_chunksize(alg) == Val(nothing) ) ? nothing : get_chunksize(alg) + num_of_chunks = div(maxcolor, isnothing(chunk_size) ? + getsize(ForwardDiff.pickchunksize(maxcolor)) : _unwrap_val(chunk_size), + RoundUp) - df .= first.(ForwardDiff.partials.(grad_config)) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - elseif autodiff_alg isa AutoFiniteDiff - FiniteDiff.finite_difference_gradient!(df, f, x, grad_config, - dir = diffdir(integrator)) - fdtype = alg_difftype(alg) - if fdtype == Val{:forward} || fdtype == Val{:central} - tmp *= 2 - if eltype(df) <: Complex - tmp *= 2 - end + integrator.stats.nf += num_of_chunks + + elseif dense isa AutoFiniteDiff + sparsity, colorvec = sparsity_colorvec(integrator.f, x) + if dense.fdtype == Val(:forward) + integrator.stats.nf += maximum(colorvec) + 1 + elseif dense.fdtype == Val(:central) + integrator.stats.nf += 2*maximum(colorvec) + elseif dense.fdtype == Val(:complex) + integrator.stats.nf += maximum(colorvec) end - integrator.stats.nf += tmp + else + integrator.stats.nf += 1 + end + + + if dense isa AutoFiniteDiff + dense = SciMLBase.@set dense.dir = diffdir(integrator) + end + + autodiff_alg = alg_autodiff(alg) + + if alg_autodiff(alg) isa AutoSparse + autodiff_alg = SciMLBase.@set autodiff_alg.dense_ad = dense else - error("$alg_autodiff not yet supported in derivative! function") + autodiff_alg = dense end - nothing -end -function derivative(f, x::Union{Number, AbstractArray{<:Number}}, - integrator) - local d - tmp = length(x) # We calculate derivative for all elements in gradient - alg = unwrap_alg(integrator, true) - if alg_autodiff(alg) isa AutoForwardDiff - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - if integrator.iter == 1 + if integrator.iter == 1 try - d = ForwardDiff.derivative(f, x) + jac = DI.jacobian(f, autodiff_alg, x) catch e - throw(FirstAutodiffTgradError(e)) + throw(FirstAutodiffJacError(e)) end else - d = ForwardDiff.derivative(f, x) - end - elseif alg_autodiff(alg) isa AutoFiniteDiff - d = FiniteDiff.finite_difference_derivative(f, x, alg_difftype(alg), - dir = diffdir(integrator)) - if alg_difftype(alg) === Val{:central} || alg_difftype(alg) === Val{:forward} - tmp *= 2 - end - integrator.stats.nf += tmp - d - else - error("$alg_autodiff not yet supported in derivative function") + jac = DI.jacobian(f, autodiff_alg, x) end -end -jacobian_autodiff(f, x, odefun, alg) = (ForwardDiff.derivative(f, x), 1, alg) -function jacobian_autodiff(f, x::AbstractArray, odefun, alg) - jac_prototype = odefun.jac_prototype - sparsity, colorvec = sparsity_colorvec(odefun, x) - maxcolor = maximum(colorvec) - chunk_size = get_chunksize(alg) === Val(0) ? nothing : get_chunksize(alg) - num_of_chunks = chunk_size === nothing ? - Int(ceil(maxcolor / getsize(ForwardDiff.pickchunksize(maxcolor)))) : - Int(ceil(maxcolor / _unwrap_val(chunk_size))) - ( - forwarddiff_color_jacobian(f, x, colorvec = colorvec, sparsity = sparsity, - jac_prototype = jac_prototype, chunksize = chunk_size), - num_of_chunks) + return jac end -function _nfcount(N, ::Type{diff_type}) where {diff_type} - if diff_type === Val{:complex} - tmp = N - elseif diff_type === Val{:forward} - tmp = N + 1 +# fallback for scalar x, is needed for calc_J to work +function jacobian(f, x, integrator) + alg = unwrap_alg(integrator, true) + + dense = ADTypes.dense_ad(alg_autodiff(alg)) + + if dense isa AutoForwardDiff + integrator.stats.nf += 1 + elseif dense isa AutoFiniteDiff + if dense.fdtype == Val(:forward) + integrator.stats.nf += 2 + elseif dense.fdtype == Val(:central) + integrator.stats.nf += 2 + elseif dense.fdtype == Val(:complex) + integrator.stats.nf += 1 + end else - tmp = 2N + integrator.stats.nf += 1 end - tmp -end -function jacobian_finitediff(f, x, ::Type{diff_type}, dir, colorvec, sparsity, - jac_prototype) where {diff_type} - (FiniteDiff.finite_difference_derivative(f, x, diff_type, eltype(x), dir = dir), 2) -end -function jacobian_finitediff(f, x::AbstractArray, ::Type{diff_type}, dir, colorvec, - sparsity, jac_prototype) where {diff_type} - f_in = diff_type === Val{:forward} ? f(x) : similar(x) - ret_eltype = eltype(f_in) - J = FiniteDiff.finite_difference_jacobian(f, x, diff_type, ret_eltype, f_in, - dir = dir, colorvec = colorvec, - sparsity = sparsity, - jac_prototype = jac_prototype) - return J, _nfcount(maximum(colorvec), diff_type) -end -function jacobian(f, x, integrator) - alg = unwrap_alg(integrator, true) - local tmp - if alg_autodiff(alg) isa AutoForwardDiff - if integrator.iter == 1 - try - J, tmp = jacobian_autodiff(f, x, integrator.f, alg) - catch e - throw(FirstAutodiffJacError(e)) - end - else - J, tmp = jacobian_autodiff(f, x, integrator.f, alg) + if dense isa AutoFiniteDiff + dense = SciMLBase.@set dense.dir = diffdir(integrator) + end + + autodiff_alg = alg_autodiff(alg) + + if autodiff_alg isa AutoSparse + autodiff_alg = SciMLBase.@set autodiff_alg.dense_ad = dense + else + autodiff_alg = dense + end + + if integrator.iter == 1 + try + jac = DI.derivative(f, autodiff_alg, x) + catch e + throw(FirstAutodiffJacError(e)) end - elseif alg_autodiff(alg) isa AutoFiniteDiff - jac_prototype = integrator.f.jac_prototype - sparsity, colorvec = sparsity_colorvec(integrator.f, x) - dir = diffdir(integrator) - J, tmp = jacobian_finitediff(f, x, alg_difftype(alg), dir, colorvec, sparsity, - jac_prototype) else - bleh + jac = DI.derivative(f, autodiff_alg, x) end - integrator.stats.nf += tmp - J -end -function jacobian_finitediff_forward!(J, f, x, jac_config, forwardcache, integrator) - (FiniteDiff.finite_difference_jacobian!(J, f, x, jac_config, forwardcache, - dir = diffdir(integrator)); - maximum(jac_config.colorvec)) -end -function jacobian_finitediff!(J, f, x, jac_config, integrator) - (FiniteDiff.finite_difference_jacobian!(J, f, x, jac_config, - dir = diffdir(integrator)); - 2 * maximum(jac_config.colorvec)) + return jac end function jacobian!(J::AbstractMatrix{<:Number}, f, x::AbstractArray{<:Number}, fx::AbstractArray{<:Number}, integrator::DiffEqBase.DEIntegrator, jac_config) alg = unwrap_alg(integrator, true) - if alg_autodiff(alg) isa AutoForwardDiff - if integrator.iter == 1 - try - forwarddiff_color_jacobian!(J, f, x, jac_config) - catch e - throw(FirstAutodiffJacError(e)) - end + + dense = ADTypes.dense_ad(alg_autodiff(alg)) + + if dense isa AutoForwardDiff + if alg_autodiff(alg) isa AutoSparse + integrator.stats.nf += maximum(SparseMatrixColorings.ncolors(jac_config[1])) else - forwarddiff_color_jacobian!(J, f, x, jac_config) + sparsity, colorvec = sparsity_colorvec(integrator.f, x) + maxcolor = maximum(colorvec) + chunk_size = (get_chunksize(alg) == Val(0) || get_chunksize(alg) == Val(nothing)) ? nothing : get_chunksize(alg) + num_of_chunks = chunk_size === nothing ? + Int(ceil(maxcolor / getsize(ForwardDiff.pickchunksize(maxcolor)))) : + Int(ceil(maxcolor / _unwrap_val(chunk_size))) + + integrator.stats.nf += num_of_chunks + end + + elseif dense isa AutoFiniteDiff + sparsity, colorvec = sparsity_colorvec(integrator.f, x) + if dense.fdtype == Val(:forward) + integrator.stats.nf += maximum(colorvec) + 1 + elseif dense.fdtype == Val(:central) + integrator.stats.nf += 2 * maximum(colorvec) + elseif dense.fdtype == Val(:complex) + integrator.stats.nf += maximum(colorvec) end - OrdinaryDiffEqCore.increment_nf!(integrator.stats, maximum(jac_config.colorvec)) - elseif alg_autodiff(alg) isa AutoFiniteDiff - isforward = alg_difftype(alg) === Val{:forward} - if isforward - forwardcache = get_tmp_cache(integrator, alg, unwrap_cache(integrator, true))[2] - f(forwardcache, x) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - tmp = jacobian_finitediff_forward!(J, f, x, jac_config, forwardcache, - integrator) - else # not forward difference - tmp = jacobian_finitediff!(J, f, x, jac_config, integrator) + else + integrator.stats.nf += 1 + end + + if dense isa AutoFiniteDiff + config = diffdir(integrator) > 0 ? jac_config[1] : jac_config[2] + else + config = jac_config[1] + end + + if integrator.iter == 1 + try + DI.jacobian!(f, fx, J, config, alg_autodiff(alg), x) + catch e + throw(FirstAutodiffJacError(e)) end - integrator.stats.nf += tmp else - error("$alg_autodiff not yet supported in jacobian! function") + DI.jacobian!(f, fx, J, config, alg_autodiff(alg), x) end + nothing end -function build_jac_config(alg, f::F1, uf::F2, du1, uprev, u, tmp, du2) where {F1, F2} +function build_jac_config(alg, f::F1, uf::F2, du1, uprev, + u, tmp, du2) where {F1, F2} + haslinsolve = hasfield(typeof(alg), :linsolve) - if !DiffEqBase.has_jac(f) && # No Jacobian if has analytical solution - (!DiffEqBase.has_Wfact_t(f)) && - ((concrete_jac(alg) === nothing && (!haslinsolve || (haslinsolve && # No Jacobian if linsolve doesn't want it - (alg.linsolve === nothing || LinearSolve.needs_concrete_A(alg.linsolve))))) || - (concrete_jac(alg) !== nothing && concrete_jac(alg))) # Jacobian if explicitly asked for + if !DiffEqBase.has_jac(f) && + (!DiffEqBase.has_Wfact_t(f)) && + ((concrete_jac(alg) === nothing && (!haslinsolve || (haslinsolve && + (alg.linsolve === nothing || LinearSolve.needs_concrete_A(alg.linsolve))))) || + (concrete_jac(alg) !== nothing && concrete_jac(alg))) + jac_prototype = f.jac_prototype if jac_prototype isa SparseMatrixCSC @@ -278,39 +251,36 @@ function build_jac_config(alg, f::F1, uf::F2, du1, uprev, u, tmp, du2) where {F1 @. @view(jac_prototype[idxs]) = @view(f.mass_matrix[idxs]) end end + uf = SciMLBase.@set uf.f = SciMLBase.unwrapped_f(uf.f) - sparsity, colorvec = sparsity_colorvec(f, u) - if alg_autodiff(alg) isa AutoForwardDiff - _chunksize = get_chunksize(alg) === Val(0) ? nothing : get_chunksize(alg) # SparseDiffEq uses different convection... - T = if standardtag(alg) - typeof(ForwardDiff.Tag(OrdinaryDiffEqTag(), eltype(u))) - else - typeof(ForwardDiff.Tag(uf, eltype(u))) - end + autodiff_alg = alg_autodiff(alg) + dense = autodiff_alg isa AutoSparse ? ADTypes.dense_ad(autodiff_alg) : autodiff_alg - if _chunksize === Val{nothing}() - _chunksize = nothing - end - jac_config = ForwardColorJacCache(uf, uprev, _chunksize; colorvec = colorvec, - sparsity = sparsity, tag = T) - elseif alg_autodiff(alg) isa AutoFiniteDiff - if alg_difftype(alg) !== Val{:complex} - jac_config = FiniteDiff.JacobianCache(tmp, du1, du2, alg_difftype(alg), - colorvec = colorvec, - sparsity = sparsity) + if dense isa AutoFiniteDiff + dir_forward = @set dense.dir = 1 + dir_reverse = @set dense.dir = -1 + + if autodiff_alg isa AutoSparse + autodiff_alg_forward = @set autodiff_alg.dense_ad = dir_forward + autodiff_alg_reverse = @set autodiff_alg.dense_ad = dir_reverse else - jac_config = FiniteDiff.JacobianCache(Complex{eltype(tmp)}.(tmp), - Complex{eltype(du1)}.(du1), nothing, - alg_difftype(alg), eltype(u), - colorvec = colorvec, - sparsity = sparsity) + autodiff_alg_forward = dir_forward + autodiff_alg_reverse = dir_reverse end + + jac_config_forward = DI.prepare_jacobian(uf, du1, autodiff_alg_forward, u) + jac_config_reverse = DI.prepare_jacobian(uf, du1, autodiff_alg_reverse, u) + + jac_config = (jac_config_forward, jac_config_reverse) else - error("$alg_autodiff not yet supported in build_jac_config function") + jac_config1 = DI.prepare_jacobian(uf, du1, alg_autodiff(alg), u) + jac_config = (jac_config1, jac_config1) end - else - jac_config = nothing + + else + jac_config = (nothing, nothing) end + jac_config end @@ -324,74 +294,74 @@ function get_chunksize(jac_config::ForwardDiff.JacobianConfig{ Val(N) end # don't degrade compile time information to runtime information -function resize_jac_config!(jac_config::SparseDiffTools.ForwardColorJacCache, i) - resize!(jac_config.fx, i) - resize!(jac_config.dx, i) - resize!(jac_config.t, i) - ps = SparseDiffTools.adapt.(DiffEqBase.parameterless_type(jac_config.dx), - SparseDiffTools.generate_chunked_partials(jac_config.dx, - 1:length(jac_config.dx), - Val(ForwardDiff.npartials(jac_config.t[1])))) - resize!(jac_config.p, length(ps)) - jac_config.p .= ps -end +function resize_jac_config!(cache, integrator) + if !isnothing(cache.jac_config) && !isnothing(cache.jac_config[1]) + uf = cache.uf + uf = SciMLBase.@set uf.f = SciMLBase.unwrapped_f(uf.f) -function resize_jac_config!(jac_config::FiniteDiff.JacobianCache, i) - resize!(jac_config, i) - jac_config -end + # for correct FiniteDiff dirs + autodiff_alg = alg_autodiff(integrator.alg) + if autodiff_alg isa AutoFiniteDiff + ad_right = SciMLBase.@set autodiff_alg.dir = 1 + ad_left = SciMLBase.@set autodiff_alg.dir = -1 + else + ad_right = autodiff_alg + ad_left = autodiff_alg + end -function resize_grad_config!(grad_config::AbstractArray, i) - resize!(grad_config, i) - grad_config + cache.jac_config = ([DI.prepare!_jacobian( + uf, cache.du1, config, ad, integrator.u) + for (ad, config) in zip( + (ad_right, ad_left), cache.jac_config)]...,) + end + cache.jac_config end -function resize_grad_config!(grad_config::ForwardDiff.DerivativeConfig, i) - resize!(grad_config.duals, i) - grad_config -end +function resize_grad_config!(cache, integrator) + if !isnothing(cache.grad_config) && !isnothing(cache.grad_config[1]) + + # for correct FiniteDiff dirs + autodiff_alg = ADTypes.dense_ad(alg_autodiff(integrator.alg)) + if autodiff_alg isa AutoFiniteDiff + ad_right = SciMLBase.@set autodiff_alg.dir = 1 + ad_left = SciMLBase.@set autodiff_alg.dir = -1 + else + ad_right = autodiff_alg + ad_left = autodiff_alg + end -function resize_grad_config!(grad_config::FiniteDiff.GradientCache, i) - @unpack fx, c1, c2 = grad_config - fx !== nothing && resize!(fx, i) - c1 !== nothing && resize!(c1, i) - c2 !== nothing && resize!(c2, i) - grad_config + cache.grad_config = ([DI.prepare!_derivative( + cache.tf, cache.du1, config, ad, integrator.t) + for (ad, config) in zip( + (ad_right, ad_left), cache.grad_config)]...,) + end + cache.grad_config end + + + + function build_grad_config(alg, f::F1, tf::F2, du1, t) where {F1, F2} if !DiffEqBase.has_tgrad(f) - if alg_autodiff(alg) isa AutoForwardDiff - T = if standardtag(alg) - typeof(ForwardDiff.Tag(OrdinaryDiffEqTag(), eltype(du1))) - else - typeof(ForwardDiff.Tag(f, eltype(du1))) - end + ad = ADTypes.dense_ad(alg_autodiff(alg)) - if du1 isa Array - dualt = Dual{T, eltype(du1), 1}(first(du1) * t, - ForwardDiff.Partials((one(eltype(du1)),))) - grad_config = similar(du1, typeof(dualt)) - fill!(grad_config, false) - else - grad_config = ArrayInterface.restructure(du1, - Dual{ - T, - eltype(du1), - 1 - }.(du1, - (ForwardDiff.Partials((one(eltype(du1)),)),)) .* - false) - end - elseif alg_autodiff(alg) isa AutoFiniteDiff - grad_config = FiniteDiff.GradientCache(du1, t, alg_difftype(alg)) + if ad isa AutoFiniteDiff + dir_true = @set ad.dir = 1 + dir_false = @set ad.dir = -1 + + grad_config_true = DI.prepare_derivative(tf, du1, dir_true, t) + grad_config_false = DI.prepare_derivative(tf, du1, dir_false, t) + + grad_config = (grad_config_true, grad_config_false) else - error("$alg_autodiff not yet supported in build_grad_config function") + grad_config1 = DI.prepare_derivative(tf,du1,ad,t) + grad_config = (grad_config1, grad_config1) end + return grad_config else - grad_config = nothing + return (nothing, nothing) end - grad_config end function sparsity_colorvec(f, x) @@ -407,7 +377,9 @@ function sparsity_colorvec(f, x) end end + col_alg = SparseMatrixColorings.GreedyColoringAlgorithm() + col_prob = SparseMatrixColorings.ColoringProblem() colorvec = DiffEqBase.has_colorvec(f) ? f.colorvec : - (isnothing(sparsity) ? (1:length(x)) : matrix_colors(sparsity)) + (isnothing(sparsity) ? (1:length(x)) : SparseMatrixColorings.column_colors(SparseMatrixColorings.coloring(sparsity, col_prob, col_alg))) sparsity, colorvec end diff --git a/lib/OrdinaryDiffEqDifferentiation/src/linsolve_utils.jl b/lib/OrdinaryDiffEqDifferentiation/src/linsolve_utils.jl index aa48161e1b..48e94e8718 100644 --- a/lib/OrdinaryDiffEqDifferentiation/src/linsolve_utils.jl +++ b/lib/OrdinaryDiffEqDifferentiation/src/linsolve_utils.jl @@ -28,16 +28,16 @@ function dolinsolve(integrator, linsolve; A = nothing, linu = nothing, b = nothi linres = solve!(linsolve; reltol) + ad = alg_autodiff(_alg) isa ADTypes.AutoSparse ? ADTypes.dense_ad(alg_autodiff(_alg)) : alg_autodiff(_alg) + # TODO: this ignores the add of the `f` count for add_steps! if integrator isa SciMLBase.DEIntegrator && _alg.linsolve !== nothing && !LinearSolve.needs_concrete_A(_alg.linsolve) && linsolve.A isa WOperator && linsolve.A.J isa AbstractSciMLOperator - if alg_autodiff(_alg) isa AutoForwardDiff - integrator.stats.nf += linres.iters - elseif alg_autodiff(_alg) isa AutoFiniteDiff - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 2) * linres.iters + if ad isa ADTypes.AutoFiniteDiff || ad isa ADTypes.AutoFiniteDifferences + OrdinaryDiffEqCore.increment_nf!(integrator.stats, 2 * linres.iters) else - error("$alg_autodiff not yet supported in dolinsolve function") + integrator.stats.nf += linres.iters end end diff --git a/lib/OrdinaryDiffEqDifferentiation/src/operators.jl b/lib/OrdinaryDiffEqDifferentiation/src/operators.jl new file mode 100644 index 0000000000..e29aa1f6e0 --- /dev/null +++ b/lib/OrdinaryDiffEqDifferentiation/src/operators.jl @@ -0,0 +1,91 @@ +""" + JVPCache{T} <: AbstractSciMLOperator{T} + +JVPCache provides a JVP operator wrapper for performing the DifferentiationInterface pushforward operation. + +### Constructor + +```julia + JVPCache(f::DiffEqBase.AbstractDiffEqFunction, du, u, p, t; autodiff) +``` +JVPCache construction builds a DifferentiationInterface "prep" object using `prepare_pushforward!`. The "prep" object is used +when applying the operator. + +### Computing the JVP + +Computing the JVP is done with the DifferentiationInterface function `pushforward!`, which takes advantage of the preparation done upon construction. +""" +@concrete mutable struct JVPCache{T} <: SciMLOperators.AbstractSciMLOperator{T} + jvp_op + f + du + u + p + t +end + +SciMLBase.isinplace(::JVPCache) = true +ArrayInterface.can_setindex(::JVPCache) = false +function ArrayInterface.restructure(y::JVPCache, x::JVPCache) + @assert size(y)==size(x) "cannot restructure operators. ensure their sizes match." + return x +end + +function ConstructionBase.constructorof(::Type{<:JVPCache{T}}) where {T} + return JVPCache{T} +end + +Base.size(J::JVPCache) = (length(J.u), length(J.u)) + +function JVPCache(f::DiffEqBase.AbstractDiffEqFunction, du, u, p, t; autodiff) + jvp_op = prepare_jvp(f, du, u, p, t, autodiff) + return JVPCache{eltype(du)}(jvp_op, f, du, u, p, t) +end + +function (op::JVPCache)(v, u, p, t) + op.jvp_op(op.du, v, u, p, t) + return res +end + +function (op::JVPCache)(Jv, v, u, p, t) + op.jvp_op(Jv, v, u, p, t) + return Jv +end + +Base.:*(J::JVPCache, v::AbstractArray) = J.jvp_op(v, J.u, J.p, J.t) +Base.:*(J::JVPCache, v::Number) = J.jvp_op(v, J.u, J.p, J.t) + +function LinearAlgebra.mul!( + Jv::AbstractArray, J::JVPCache, v::AbstractArray) + J.jvp_op(Jv, v, J.u, J.p, J.t) + return Jv +end + +# helper functions + +function prepare_jvp(f::DiffEqBase.AbstractDiffEqFunction, du, u, p, t, autodiff) + SciMLBase.has_jvp(f) && return f.jvp + autodiff = autodiff isa AutoSparse ? ADTypes.dense_ad(autodiff) : autodiff + @assert DI.check_inplace(autodiff) "AD backend $(autodiff) doesn't support in-place problems." + di_prep = DI.prepare_pushforward(f, du, autodiff, u, (u,), DI.ConstantOrCache(p), DI.Constant(t)) + return (Jv, v, u, p, t) -> DI.pushforward!(f, du, (reshape(Jv, size(du)),), di_prep, autodiff, u, (reshape(v,size(u)),), DI.ConstantOrCache(p), DI.Constant(t)) +end + +function SciMLOperators.update_coefficients!(J::JVPCache, u, p, t) + J.u = u + J.p = p + J.t = t +end + + +function resize_JVPCache!(J::JVPCache,f, du, u, p, t, autodiff) + J.jvp_op = prepare_jvp(f, du, u, p, t, autodiff) + J.du = du + update_coefficients!(J, u, p, t) +end + +function resize_JVPCache!(J::JVPCache, f, du, u, autodiff) + J.jvp_op = prepare_jvp(f, du, u, J.p, J.t, autodiff) + J.du = du + update_coefficients!(J,u,J.p, J.t) +end \ No newline at end of file diff --git a/lib/OrdinaryDiffEqFIRK/src/firk_caches.jl b/lib/OrdinaryDiffEqFIRK/src/firk_caches.jl index c753b3351a..6daf0b299f 100644 --- a/lib/OrdinaryDiffEqFIRK/src/firk_caches.jl +++ b/lib/OrdinaryDiffEqFIRK/src/firk_caches.jl @@ -91,16 +91,17 @@ function alg_cache(alg::RadauIIA3, u, rate_prototype, ::Type{uEltypeNoUnits}, fw1 = zero(rate_prototype) fw2 = zero(rate_prototype) - J, W1 = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true)) - W1 = similar(J, Complex{eltype(W1)}) - recursivefill!(W1, false) - du1 = zero(rate_prototype) tmp = zero(u) atmp = similar(u, uEltypeNoUnits) recursivefill!(atmp, false) - jac_config = jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, dw12) + jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, dw12) + + J, W1 = build_J_W(alg, u, uprev, p, t, dt, f, jac_config, uEltypeNoUnits, Val(true)) + W1 = similar(J, Complex{eltype(W1)}) + recursivefill!(W1, false) + linprob = LinearProblem(W1, _vec(cubuff); u0 = _vec(dw12)) linsolve = init( @@ -228,13 +229,6 @@ function alg_cache(alg::RadauIIA5, u, rate_prototype, ::Type{uEltypeNoUnits}, fw2 = zero(rate_prototype) fw3 = zero(rate_prototype) - J, W1 = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true)) - if J isa AbstractSciMLOperator - error("Non-concrete Jacobian not yet supported by RadauIIA5.") - end - W2 = similar(J, Complex{eltype(W1)}) - recursivefill!(W2, false) - du1 = zero(rate_prototype) tmp = zero(u) @@ -242,6 +236,13 @@ function alg_cache(alg::RadauIIA5, u, rate_prototype, ::Type{uEltypeNoUnits}, recursivefill!(atmp, false) jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, dw1) + J, W1 = build_J_W(alg, u, uprev, p, t, dt, f, jac_config, uEltypeNoUnits, Val(true)) + if J isa AbstractSciMLOperator + error("Non-concrete Jacobian not yet supported by RadauIIA5.") + end + W2 = similar(J, Complex{eltype(W1)}) + recursivefill!(W2, false) + linprob = LinearProblem(W1, _vec(ubuff); u0 = _vec(dw1)) linsolve1 = init( linprob, alg.linsolve, alias = LinearAliasSpecifier(alias_A = true, alias_b = true), @@ -409,15 +410,6 @@ function alg_cache(alg::RadauIIA9, u, rate_prototype, ::Type{uEltypeNoUnits}, fw4 = zero(rate_prototype) fw5 = zero(rate_prototype) - J, W1 = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true)) - if J isa AbstractSciMLOperator - error("Non-concrete Jacobian not yet supported by RadauIIA5.") - end - W2 = similar(J, Complex{eltype(W1)}) - W3 = similar(J, Complex{eltype(W1)}) - recursivefill!(W2, false) - recursivefill!(W3, false) - du1 = zero(rate_prototype) tmp = zero(u) @@ -434,6 +426,15 @@ function alg_cache(alg::RadauIIA9, u, rate_prototype, ::Type{uEltypeNoUnits}, recursivefill!(atmp, false) jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, dw1) + J, W1 = build_J_W(alg, u, uprev, p, t, dt, f, jac_config, uEltypeNoUnits, Val(true)) + if J isa AbstractSciMLOperator + error("Non-concrete Jacobian not yet supported by RadauIIA5.") + end + W2 = similar(J, Complex{eltype(W1)}) + W3 = similar(J, Complex{eltype(W1)}) + recursivefill!(W2, false) + recursivefill!(W3, false) + linprob = LinearProblem(W1, _vec(ubuff); u0 = _vec(dw1)) linsolve1 = init( linprob, alg.linsolve, alias = LinearAliasSpecifier(alias_A = true, alias_b = true), @@ -623,14 +624,6 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits} k = ks[1] - J, W1 = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true)) - if J isa AbstractSciMLOperator - error("Non-concrete Jacobian not yet supported by AdaptiveRadau.") - end - - W2 = [similar(J, Complex{eltype(W1)}) for _ in 1:((max_stages - 1) ÷ 2)] - recursivefill!.(W2, false) - du1 = zero(rate_prototype) tmp = zero(u) @@ -640,6 +633,14 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits} jac_config = build_jac_config(alg, f, uf, du1, uprev, u, zero(u), dw1) + J, W1 = build_J_W(alg, u, uprev, p, t, dt, f, jac_config, uEltypeNoUnits, Val(true)) + if J isa AbstractSciMLOperator + error("Non-concrete Jacobian not yet supported by AdaptiveRadau.") + end + + W2 = [similar(J, Complex{eltype(W1)}) for _ in 1:((max_stages - 1) ÷ 2)] + recursivefill!.(W2, false) + linprob = LinearProblem(W1, _vec(ubuff); u0 = _vec(dw1)) linsolve1 = init( linprob, alg.linsolve, alias = LinearAliasSpecifier(alias_A = true, alias_b = true), diff --git a/lib/OrdinaryDiffEqNonlinearSolve/src/OrdinaryDiffEqNonlinearSolve.jl b/lib/OrdinaryDiffEqNonlinearSolve/src/OrdinaryDiffEqNonlinearSolve.jl index d77a0c4563..bffaed0a0f 100644 --- a/lib/OrdinaryDiffEqNonlinearSolve/src/OrdinaryDiffEqNonlinearSolve.jl +++ b/lib/OrdinaryDiffEqNonlinearSolve/src/OrdinaryDiffEqNonlinearSolve.jl @@ -1,6 +1,6 @@ module OrdinaryDiffEqNonlinearSolve -import ADTypes: AutoFiniteDiff, AutoForwardDiff +using ADTypes import SciMLBase import SciMLBase: init, solve, solve!, remake diff --git a/lib/OrdinaryDiffEqNonlinearSolve/src/newton.jl b/lib/OrdinaryDiffEqNonlinearSolve/src/newton.jl index d5e47246af..e873422e5e 100644 --- a/lib/OrdinaryDiffEqNonlinearSolve/src/newton.jl +++ b/lib/OrdinaryDiffEqNonlinearSolve/src/newton.jl @@ -529,9 +529,8 @@ function Base.resize!(nlcache::NLNewtonCache, ::AbstractNLSolver, integrator, i: resize!(nlcache.atmp, i) resize!(nlcache.dz, i) resize!(nlcache.du1, i) - if nlcache.jac_config !== nothing - resize_jac_config!(nlcache.jac_config, i) - end + + resize_jac_config!(nlcache, integrator) resize!(nlcache.weight, i) # resize J and W (or rather create new ones of appropriate size and type) @@ -539,3 +538,4 @@ function Base.resize!(nlcache::NLNewtonCache, ::AbstractNLSolver, integrator, i: nothing end + diff --git a/lib/OrdinaryDiffEqNonlinearSolve/src/utils.jl b/lib/OrdinaryDiffEqNonlinearSolve/src/utils.jl index 8d504ef4da..f6071099c0 100644 --- a/lib/OrdinaryDiffEqNonlinearSolve/src/utils.jl +++ b/lib/OrdinaryDiffEqNonlinearSolve/src/utils.jl @@ -72,7 +72,8 @@ mutable struct DAEResidualJacobianWrapper{isAD, F, pType, duType, uType, alphaTy uprev::uprevType t::tType function DAEResidualJacobianWrapper(alg, f, p, α, invγdt, tmp, uprev, t) - isautodiff = alg_autodiff(alg) isa AutoForwardDiff + ad = ADTypes.dense_ad(alg_autodiff(alg)) + isautodiff = ad isa AutoForwardDiff if isautodiff tmp_du = PreallocationTools.dualcache(uprev) tmp_u = PreallocationTools.dualcache(uprev) @@ -86,6 +87,13 @@ mutable struct DAEResidualJacobianWrapper{isAD, F, pType, duType, uType, alphaTy end end +function SciMLBase.setproperties(wrap::DAEResidualJacobianWrapper, patch::NamedTuple) + for key in keys(patch) + setproperty!(wrap, key, patch[key]) + end + return wrap +end + is_autodiff(m::DAEResidualJacobianWrapper{isAD}) where {isAD} = isAD function (m::DAEResidualJacobianWrapper)(out, x) @@ -173,7 +181,7 @@ function build_nlsolver( if nlalg isa Union{NLNewton, NonlinearSolveAlg} nf = nlsolve_f(f, alg) - J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true)) + J, W = build_J_W(alg, u, uprev, p, t, dt, f, nothing, uEltypeNoUnits, Val(true)) # TODO: check if the solver is iterative weight = zero(u) @@ -288,7 +296,7 @@ function build_nlsolver( tType = typeof(t) invγdt = inv(oneunit(t) * one(uTolType)) - J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(false)) + J, W = build_J_W(alg, u, uprev, p, t, dt, f, nothing, uEltypeNoUnits, Val(false)) if nlalg isa NonlinearSolveAlg α = tTypeNoUnits(α) dt = tTypeNoUnits(dt) diff --git a/lib/OrdinaryDiffEqRosenbrock/Project.toml b/lib/OrdinaryDiffEqRosenbrock/Project.toml index 6e333a77e1..0fa86d7cc2 100644 --- a/lib/OrdinaryDiffEqRosenbrock/Project.toml +++ b/lib/OrdinaryDiffEqRosenbrock/Project.toml @@ -6,6 +6,7 @@ version = "1.8.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" @@ -26,6 +27,7 @@ Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" ADTypes = "1.11" DiffEqBase = "6.152.2" DiffEqDevTools = "2.44.4" +DifferentiationInterface = "0.6.48" FastBroadcast = "0.3.5" FiniteDiff = "2.24.0" ForwardDiff = "0.10.36" @@ -49,6 +51,7 @@ julia = "1.10" [extras] DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" @@ -59,4 +62,4 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["DiffEqDevTools", "Random", "OrdinaryDiffEqNonlinearSolve", "SafeTestsets", "Test", "LinearAlgebra", "LinearSolve", "ForwardDiff", "ODEProblemLibrary"] +test = ["DiffEqDevTools", "Random", "OrdinaryDiffEqNonlinearSolve", "SafeTestsets", "Test", "LinearAlgebra", "LinearSolve", "ForwardDiff", "ODEProblemLibrary", "Enzyme"] diff --git a/lib/OrdinaryDiffEqRosenbrock/src/OrdinaryDiffEqRosenbrock.jl b/lib/OrdinaryDiffEqRosenbrock/src/OrdinaryDiffEqRosenbrock.jl index 39b49f3573..10de0f3a9a 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/OrdinaryDiffEqRosenbrock.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/OrdinaryDiffEqRosenbrock.jl @@ -18,13 +18,14 @@ using MuladdMacro, FastBroadcast, RecursiveArrayTools import MacroTools using MacroTools: @capture using DiffEqBase: @def +import DifferentiationInterface as DI import LinearSolve import LinearSolve: UniformScaling import ForwardDiff using FiniteDiff using LinearAlgebra: mul!, diag, diagm, I, Diagonal, norm -import ADTypes: AutoForwardDiff, AbstractADType -import OrdinaryDiffEqCore +using ADTypes +import OrdinaryDiffEqCore, OrdinaryDiffEqDifferentiation using OrdinaryDiffEqDifferentiation: TimeDerivativeWrapper, TimeGradientWrapper, UDerivativeWrapper, UJacobianWrapper, @@ -32,7 +33,7 @@ using OrdinaryDiffEqDifferentiation: TimeDerivativeWrapper, TimeGradientWrapper, build_jac_config, issuccess_W, jacobian2W!, resize_jac_config!, resize_grad_config!, calc_W, calc_rosenbrock_differentiation!, build_J_W, - UJacobianWrapper, dolinsolve + UJacobianWrapper, dolinsolve, WOperator, resize_J_W! using Reexport @reexport using DiffEqBase diff --git a/lib/OrdinaryDiffEqRosenbrock/src/generic_rosenbrock.jl b/lib/OrdinaryDiffEqRosenbrock/src/generic_rosenbrock.jl index f83c1d29fc..becc128cc4 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/generic_rosenbrock.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/generic_rosenbrock.jl @@ -227,7 +227,7 @@ function gen_algcache(cacheexpr::Expr,constcachename::Symbol,algname::Symbol,tab function alg_cache(alg::$algname,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{false}) tf = TimeDerivativeWrapper(f,u,p) uf = UDerivativeWrapper(f,t,p) - J,W = build_J_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits,Val(false)) + J,W = build_J_W(alg,u,uprev,p,t,dt,f, nothing, uEltypeNoUnits,Val(false)) $constcachename(tf,uf,$tabname(constvalue(uBottomEltypeNoUnits),constvalue(tTypeNoUnits)),J,W,nothing) end function alg_cache(alg::$algname,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{true}) @@ -238,7 +238,6 @@ function gen_algcache(cacheexpr::Expr,constcachename::Symbol,algname::Symbol,tab fsalfirst = zero(rate_prototype) fsallast = zero(rate_prototype) dT = zero(rate_prototype) - J,W = build_J_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits,Val(true)) tmp = zero(rate_prototype) atmp = similar(u, uEltypeNoUnits) weight = similar(u, uEltypeNoUnits) @@ -247,12 +246,15 @@ function gen_algcache(cacheexpr::Expr,constcachename::Symbol,algname::Symbol,tab tf = TimeGradientWrapper(f,uprev,p) uf = UJacobianWrapper(f,t,p) linsolve_tmp = zero(rate_prototype) + + grad_config = build_grad_config(alg,f,tf,du1,t) + jac_config = build_jac_config(alg,f,uf,du1,uprev,u,tmp,du2) + J, W = build_J_W(alg, u, uprev, p, t, dt, f, jac_config, uEltypeNoUnits, Val(true)) + linprob = LinearProblem(W,_vec(linsolve_tmp); u0=_vec(tmp)) linsolve = init(linprob,alg.linsolve,alias = LinearAliasSpecifier(alias_A=true,alias_b=true), Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight))), - Pr = Diagonal(_vec(weight))) - grad_config = build_grad_config(alg,f,tf,du1,t) - jac_config = build_jac_config(alg,f,uf,du1,uprev,u,tmp,du2) + Pr = Diagonal(_vec(weight))) $cachename($(valsyms...)) end end diff --git a/lib/OrdinaryDiffEqRosenbrock/src/integrator_interface.jl b/lib/OrdinaryDiffEqRosenbrock/src/integrator_interface.jl index f99dc3a089..9cc3a7960b 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/integrator_interface.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/integrator_interface.jl @@ -1,8 +1,8 @@ function resize_non_user_cache!(integrator::ODEIntegrator, cache::RosenbrockMutableCache, i) - cache.J = similar(cache.J, i, i) - cache.W = similar(cache.W, i, i) - resize_jac_config!(cache.jac_config, i) - resize_grad_config!(cache.grad_config, i) + resize_J_W!(cache, integrator, i) + resize_jac_config!(cache, integrator) + resize_grad_config!(cache, integrator) + nothing end diff --git a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl index 3c4945725b..a358b81581 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl @@ -142,7 +142,6 @@ function alg_cache(alg::Rosenbrock23, u, rate_prototype, ::Type{uEltypeNoUnits}, fsalfirst = zero(rate_prototype) fsallast = zero(rate_prototype) dT = zero(rate_prototype) - J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true)) tmp = zero(rate_prototype) atmp = similar(u, uEltypeNoUnits) recursivefill!(atmp, false) @@ -153,6 +152,11 @@ function alg_cache(alg::Rosenbrock23, u, rate_prototype, ::Type{uEltypeNoUnits}, uf = UJacobianWrapper(f, t, p) linsolve_tmp = zero(rate_prototype) + grad_config = build_grad_config(alg, f, tf, du1, t) + jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2) + + J, W = build_J_W(alg, u, uprev, p, t, dt, f, jac_config, uEltypeNoUnits, Val(true)) + linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp)) Pl, Pr = wrapprecs( alg.precs(W, nothing, u, p, t, nothing, nothing, nothing, @@ -162,8 +166,7 @@ function alg_cache(alg::Rosenbrock23, u, rate_prototype, ::Type{uEltypeNoUnits}, Pl = Pl, Pr = Pr, assumptions = LinearSolve.OperatorAssumptions(true)) - grad_config = build_grad_config(alg, f, tf, du1, t) - jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2) + algebraic_vars = f.mass_matrix === I ? nothing : [all(iszero, x) for x in eachcol(f.mass_matrix)] @@ -188,7 +191,6 @@ function alg_cache(alg::Rosenbrock32, u, rate_prototype, ::Type{uEltypeNoUnits}, fsalfirst = zero(rate_prototype) fsallast = zero(rate_prototype) dT = zero(rate_prototype) - J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true)) tmp = zero(rate_prototype) atmp = similar(u, uEltypeNoUnits) recursivefill!(atmp, false) @@ -199,6 +201,12 @@ function alg_cache(alg::Rosenbrock32, u, rate_prototype, ::Type{uEltypeNoUnits}, tf = TimeGradientWrapper(f, uprev, p) uf = UJacobianWrapper(f, t, p) linsolve_tmp = zero(rate_prototype) + + grad_config = build_grad_config(alg, f, tf, du1, t) + jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2) + + J, W = build_J_W(alg, u, uprev, p, t, dt, f, jac_config, uEltypeNoUnits, Val(true)) + linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp)) Pl, Pr = wrapprecs( @@ -208,8 +216,7 @@ function alg_cache(alg::Rosenbrock32, u, rate_prototype, ::Type{uEltypeNoUnits}, linprob, alg.linsolve, alias = LinearAliasSpecifier(alias_A = true, alias_b = true), Pl = Pl, Pr = Pr, assumptions = LinearSolve.OperatorAssumptions(true)) - grad_config = build_grad_config(alg, f, tf, du1, t) - jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2) + algebraic_vars = f.mass_matrix === I ? nothing : [all(iszero, x) for x in eachcol(f.mass_matrix)] @@ -241,7 +248,7 @@ function alg_cache(alg::Rosenbrock23, u, rate_prototype, ::Type{uEltypeNoUnits}, ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} tf = TimeDerivativeWrapper(f, u, p) uf = UDerivativeWrapper(f, t, p) - J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(false)) + J, W = build_J_W(alg, u, uprev, p, t, dt, f, nothing, uEltypeNoUnits, Val(false)) linprob = nothing #LinearProblem(W,copy(u); u0=copy(u)) linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true) Rosenbrock23ConstantCache(constvalue(uBottomEltypeNoUnits), tf, uf, J, W, linsolve, @@ -271,7 +278,7 @@ function alg_cache(alg::Rosenbrock32, u, rate_prototype, ::Type{uEltypeNoUnits}, ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} tf = TimeDerivativeWrapper(f, u, p) uf = UDerivativeWrapper(f, t, p) - J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(false)) + J, W = build_J_W(alg, u, uprev, p, t, dt, f, nothing, uEltypeNoUnits, Val(false)) linprob = nothing #LinearProblem(W,copy(u); u0=copy(u)) linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true) Rosenbrock32ConstantCache(constvalue(uBottomEltypeNoUnits), tf, uf, J, W, linsolve, @@ -339,7 +346,6 @@ function alg_cache(alg::ROS3P, u, rate_prototype, ::Type{uEltypeNoUnits}, fsalfirst = zero(rate_prototype) fsallast = zero(rate_prototype) dT = zero(rate_prototype) - J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true)) tmp = zero(rate_prototype) atmp = similar(u, uEltypeNoUnits) recursivefill!(atmp, false) @@ -349,6 +355,11 @@ function alg_cache(alg::ROS3P, u, rate_prototype, ::Type{uEltypeNoUnits}, tf = TimeGradientWrapper(f, uprev, p) uf = UJacobianWrapper(f, t, p) linsolve_tmp = zero(rate_prototype) + + grad_config = build_grad_config(alg, f, tf, du1, t) + jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2) + J, W = build_J_W(alg, u, uprev, p, t, dt, f, jac_config, uEltypeNoUnits, Val(true)) + linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp)) Pl, Pr = wrapprecs( alg.precs(W, nothing, u, p, t, nothing, nothing, nothing, @@ -357,8 +368,7 @@ function alg_cache(alg::ROS3P, u, rate_prototype, ::Type{uEltypeNoUnits}, linprob, alg.linsolve, alias = LinearAliasSpecifier(alias_A = true, alias_b = true), Pl = Pl, Pr = Pr, assumptions = LinearSolve.OperatorAssumptions(true)) - grad_config = build_grad_config(alg, f, tf, du1, t) - jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2) + Rosenbrock33Cache(u, uprev, du, du1, du2, k1, k2, k3, k4, fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf, linsolve_tmp, @@ -372,7 +382,7 @@ function alg_cache(alg::ROS3P, u, rate_prototype, ::Type{uEltypeNoUnits}, ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} tf = TimeDerivativeWrapper(f, u, p) uf = UDerivativeWrapper(f, t, p) - J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(false)) + J, W = build_J_W(alg, u, uprev, p, t, dt, f, nothing, uEltypeNoUnits, Val(false)) linprob = nothing #LinearProblem(W,copy(u); u0=copy(u)) linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true) Rosenbrock33ConstantCache(tf, uf, @@ -425,7 +435,6 @@ function alg_cache(alg::Rodas3, u, rate_prototype, ::Type{uEltypeNoUnits}, fsalfirst = zero(rate_prototype) fsallast = zero(rate_prototype) dT = zero(rate_prototype) - J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true)) tmp = zero(rate_prototype) atmp = similar(u, uEltypeNoUnits) recursivefill!(atmp, false) @@ -436,6 +445,12 @@ function alg_cache(alg::Rodas3, u, rate_prototype, ::Type{uEltypeNoUnits}, tf = TimeGradientWrapper(f, uprev, p) uf = UJacobianWrapper(f, t, p) linsolve_tmp = zero(rate_prototype) + + grad_config = build_grad_config(alg, f, tf, du1, t) + jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2) + + J, W = build_J_W(alg, u, uprev, p, t, dt, f, jac_config, uEltypeNoUnits, Val(true)) + linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp)) Pl, Pr = wrapprecs( alg.precs(W, nothing, u, p, t, nothing, nothing, nothing, @@ -444,8 +459,7 @@ function alg_cache(alg::Rodas3, u, rate_prototype, ::Type{uEltypeNoUnits}, linprob, alg.linsolve, alias = LinearAliasSpecifier(alias_A = true, alias_b = true), Pl = Pl, Pr = Pr, assumptions = LinearSolve.OperatorAssumptions(true)) - grad_config = build_grad_config(alg, f, tf, du1, t) - jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2) + Rosenbrock34Cache(u, uprev, du, du1, du2, k1, k2, k3, k4, fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf, linsolve_tmp, @@ -469,7 +483,7 @@ function alg_cache(alg::Rodas3, u, rate_prototype, ::Type{uEltypeNoUnits}, ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} tf = TimeDerivativeWrapper(f, u, p) uf = UDerivativeWrapper(f, t, p) - J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(false)) + J, W = build_J_W(alg, u, uprev, p, t, dt, f, nothing, uEltypeNoUnits, Val(false)) linprob = nothing #LinearProblem(W,copy(u); u0=copy(u)) linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true) Rosenbrock34ConstantCache(tf, uf, @@ -619,7 +633,6 @@ function alg_cache(alg::Rodas23W, u, rate_prototype, ::Type{uEltypeNoUnits}, fsalfirst = zero(rate_prototype) fsallast = zero(rate_prototype) dT = zero(rate_prototype) - J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true)) tmp = zero(rate_prototype) atmp = similar(u, uEltypeNoUnits) recursivefill!(atmp, false) @@ -630,6 +643,12 @@ function alg_cache(alg::Rodas23W, u, rate_prototype, ::Type{uEltypeNoUnits}, tf = TimeGradientWrapper(f, uprev, p) uf = UJacobianWrapper(f, t, p) linsolve_tmp = zero(rate_prototype) + + grad_config = build_grad_config(alg, f, tf, du1, t) + jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2) + + J, W = build_J_W(alg, u, uprev, p, t, dt, f, jac_config, uEltypeNoUnits, Val(true)) + linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp)) Pl, Pr = wrapprecs( alg.precs(W, nothing, u, p, t, nothing, nothing, nothing, @@ -638,8 +657,7 @@ function alg_cache(alg::Rodas23W, u, rate_prototype, ::Type{uEltypeNoUnits}, linprob, alg.linsolve, alias = LinearAliasSpecifier(alias_A = true, alias_b = true), Pl = Pl, Pr = Pr, assumptions = LinearSolve.OperatorAssumptions(true)) - grad_config = build_grad_config(alg, f, tf, du1, t) - jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2) + Rodas23WCache(u, uprev, dense1, dense2, dense3, du, du1, du2, k1, k2, k3, k4, k5, fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf, linsolve_tmp, linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!, @@ -664,7 +682,6 @@ function alg_cache(alg::Rodas3P, u, rate_prototype, ::Type{uEltypeNoUnits}, fsalfirst = zero(rate_prototype) fsallast = zero(rate_prototype) dT = zero(rate_prototype) - J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true)) tmp = zero(rate_prototype) atmp = similar(u, uEltypeNoUnits) recursivefill!(atmp, false) @@ -674,6 +691,12 @@ function alg_cache(alg::Rodas3P, u, rate_prototype, ::Type{uEltypeNoUnits}, tf = TimeGradientWrapper(f, uprev, p) uf = UJacobianWrapper(f, t, p) + + grad_config = build_grad_config(alg, f, tf, du1, t) + jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2) + + J, W = build_J_W(alg, u, uprev, p, t, dt, f, jac_config, uEltypeNoUnits, Val(true)) + linsolve_tmp = zero(rate_prototype) linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp)) Pl, Pr = wrapprecs( @@ -683,8 +706,7 @@ function alg_cache(alg::Rodas3P, u, rate_prototype, ::Type{uEltypeNoUnits}, linprob, alg.linsolve, alias = LinearAliasSpecifier(alias_A = true, alias_b = true), Pl = Pl, Pr = Pr, assumptions = LinearSolve.OperatorAssumptions(true)) - grad_config = build_grad_config(alg, f, tf, du1, t) - jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2) + Rodas3PCache(u, uprev, dense1, dense2, dense3, du, du1, du2, k1, k2, k3, k4, k5, fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf, linsolve_tmp, linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!, @@ -697,7 +719,7 @@ function alg_cache(alg::Rodas23W, u, rate_prototype, ::Type{uEltypeNoUnits}, ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} tf = TimeDerivativeWrapper(f, u, p) uf = UDerivativeWrapper(f, t, p) - J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(false)) + J, W = build_J_W(alg, u, uprev, p, t, dt, f, nothing, uEltypeNoUnits, Val(false)) linprob = nothing #LinearProblem(W,copy(u); u0=copy(u)) linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true) Rodas23WConstantCache(tf, uf, @@ -712,7 +734,7 @@ function alg_cache(alg::Rodas3P, u, rate_prototype, ::Type{uEltypeNoUnits}, ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} tf = TimeDerivativeWrapper(f, u, p) uf = UDerivativeWrapper(f, t, p) - J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(false)) + J, W = build_J_W(alg, u, uprev, p, t, dt, f, nothing, uEltypeNoUnits, Val(false)) linprob = nothing #LinearProblem(W,copy(u); u0=copy(u)) linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true) Rodas3PConstantCache(tf, uf, @@ -740,7 +762,7 @@ function alg_cache( ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} tf = TimeDerivativeWrapper(f, u, p) uf = UDerivativeWrapper(f, t, p) - J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(false)) + J, W = build_J_W(alg, u, uprev, p, t, dt, f, nothing, uEltypeNoUnits, Val(false)) linprob = nothing #LinearProblem(W,copy(u); u0=copy(u)) linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true) tab = tabtype(alg)(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) @@ -772,9 +794,6 @@ function alg_cache( fsallast = zero(rate_prototype) dT = zero(rate_prototype) - # Build J and W matrices - J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true)) - # Temporary and helper variables tmp = zero(rate_prototype) atmp = similar(u, uEltypeNoUnits) @@ -784,21 +803,25 @@ function alg_cache( tf = TimeGradientWrapper(f, uprev, p) uf = UJacobianWrapper(f, t, p) - linsolve_tmp = zero(rate_prototype) - linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp)) + + grad_config = build_grad_config(alg, f, tf, du1, t) + jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2) + + J, W = build_J_W(alg, u, uprev, p, t, dt, f, jac_config, uEltypeNoUnits, Val(true)) Pl, Pr = wrapprecs( alg.precs(W, nothing, u, p, t, nothing, nothing, nothing, nothing)..., weight, tmp) - linsolve = init( - linprob, alg.linsolve, alias = LinearAliasSpecifier(alias_A = true, alias_b = true), - Pl = Pl, Pr = Pr, - assumptions = LinearSolve.OperatorAssumptions(true)) + linsolve_tmp = zero(rate_prototype) + linprob = LinearProblem(W, _vec(linsolve_tmp); u0=_vec(tmp)) - grad_config = build_grad_config(alg, f, tf, du1, t) - jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2) + linsolve = init( + linprob, alg.linsolve, alias = LinearAliasSpecifier(alias_A=true, alias_b=true), + Pl=Pl, Pr=Pr, + assumptions=LinearSolve.OperatorAssumptions(true)) + # Return the cache struct with vectors RosenbrockCache( u, uprev, dense, du, du1, du2, dtC, dtd, ks, fsalfirst, fsallast, diff --git a/lib/OrdinaryDiffEqRosenbrock/src/stiff_addsteps.jl b/lib/OrdinaryDiffEqRosenbrock/src/stiff_addsteps.jl index c7987db872..dcc8931ba3 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/stiff_addsteps.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/stiff_addsteps.jl @@ -9,18 +9,21 @@ function _ode_addsteps!(k, t, uprev, u, dt, f, p, neginvdtγ = -inv(dtγ) dto2 = dt / 2 tf.u = uprev - if cache.autodiff isa AutoForwardDiff - dT = ForwardDiff.derivative(tf, t) - else - dT = FiniteDiff.finite_difference_derivative(tf, t, dir = sign(dt)) + + autodiff_alg = cache.autodiff + + if autodiff_alg isa AutoFiniteDiff + autodiff_alg = SciMLBase.@set autodiff_alg.dir = sign(dt) end + dT = DI.derivative(tf, autodiff_alg,t) + mass_matrix = f.mass_matrix if uprev isa Number - J = ForwardDiff.derivative(uf, uprev) + J = DI.derivative(uf, autodiff_alg, uprev) W = neginvdtγ .+ J else - J = ForwardDiff.jacobian(uf, uprev) + J = DI.jacobian(uf, autodiff_alg, uprev) if mass_matrix isa UniformScaling W = neginvdtγ * mass_matrix + J else @@ -58,19 +61,22 @@ function _ode_addsteps!(k, t, uprev, u, dt, f, p, cache::RosenbrockCombinedConst # Time derivative tf.u = uprev - if cache.autodiff isa AutoForwardDiff - dT = ForwardDiff.derivative(tf, t) - else - dT = FiniteDiff.finite_difference_derivative(tf, t, dir = sign(dt)) + + autodiff_alg = cache.autodiff + + if autodiff_alg isa AutoFiniteDiff + autodiff_alg = SciMLBase.@set autodiff_alg.dir = sign(dt) end + dT = DI.derivative(tf, autodiff_alg, t) + # Jacobian uf.t = t if uprev isa AbstractArray - J = ForwardDiff.jacobian(uf, uprev) + J = DI.jacobian(uf, autodiff_alg, uprev) W = mass_matrix / dtgamma - J else - J = ForwardDiff.derivative(uf, uprev) + J = DI.derivative(uf, autodiff_alg, uprev) W = 1 / dtgamma - J end diff --git a/lib/OrdinaryDiffEqRosenbrock/test/ode_rosenbrock_tests.jl b/lib/OrdinaryDiffEqRosenbrock/test/ode_rosenbrock_tests.jl index bc5b65e080..fd2aee9e45 100644 --- a/lib/OrdinaryDiffEqRosenbrock/test/ode_rosenbrock_tests.jl +++ b/lib/OrdinaryDiffEqRosenbrock/test/ode_rosenbrock_tests.jl @@ -1,4 +1,4 @@ -using OrdinaryDiffEqRosenbrock, DiffEqDevTools, Test, LinearAlgebra, LinearSolve, ADTypes +using OrdinaryDiffEqRosenbrock, DiffEqDevTools, Test, LinearAlgebra, LinearSolve, ADTypes, Enzyme import ODEProblemLibrary: prob_ode_linear, prob_ode_2Dlinear, prob_ode_bigfloatlinear, prob_ode_bigfloat2Dlinear @@ -28,6 +28,14 @@ import LinearSolve sol = solve(prob, Rosenbrock23()) @test length(sol) < 20 + sim = test_convergence(dts, prob, Rosenbrock23(autodiff = AutoEnzyme( + mode = set_runtime_activity(Enzyme.Forward), function_annotation = Enzyme.Const))) + @test sim.𝒪est[:final]≈2 atol=testTol + + sol = solve(prob, Rosenbrock23(autodiff = AutoEnzyme( + mode = set_runtime_activity(Enzyme.Forward), function_annotation = Enzyme.Const))) + @test length(sol) < 20 + prob = prob_ode_bigfloat2Dlinear sim = test_convergence(dts, prob, Rosenbrock23(linsolve = QRFactorization())) @@ -54,6 +62,27 @@ import LinearSolve sol = solve(prob, Rosenbrock32()) @test length(sol) < 20 + sim = test_convergence(dts, + prob, + Rosenbrock32(autodiff = AutoEnzyme( + mode = set_runtime_activity(Enzyme.Forward), function_annotation = Enzyme.Const))) + @test sim.𝒪est[:final]≈3 atol=testTol + + sol = solve(prob, + Rosenbrock32(autodiff = AutoEnzyme( + mode = set_runtime_activity(Enzyme.Forward), function_annotation = Enzyme.Const))) + @test length(sol) < 20 + + sim = test_convergence(dts, + prob, + Rosenbrock32(autodiff = AutoEnzyme( + mode = set_runtime_activity(Enzyme.Forward), function_annotation = Enzyme.Const), linsolve = LinearSolve.KrylovJL())) + @test sim.𝒪est[:final]≈3 atol=testTol + + sol = solve(prob, + Rosenbrock32(autodiff = AutoEnzyme( + mode = set_runtime_activity(Enzyme.Forward), function_annotation = Enzyme.Const), linsolve = LinearSolve.KrylovJL())) + @test length(sol) < 20 ### ROS3P() prob = prob_ode_linear @@ -72,6 +101,21 @@ import LinearSolve sol = solve(prob, ROS3P()) @test length(sol) < 20 + sim = test_convergence(dts, + prob, + ROS3P( + autodiff = AutoEnzyme( + mode = set_runtime_activity(Enzyme.Forward), function_annotation = Enzyme.Const), + linsolve = LinearSolve.KrylovJL())) + @test sim.𝒪est[:final]≈3 atol=testTol + + sol = solve(prob, + ROS3P( + autodiff = AutoEnzyme( + mode = set_runtime_activity(Enzyme.Forward), function_annotation = Enzyme.Const), + linsolve = LinearSolve.KrylovJL())) + @test length(sol) < 20 + ### Rodas3() prob = prob_ode_linear @@ -90,6 +134,21 @@ import LinearSolve sol = solve(prob, Rodas3()) @test length(sol) < 20 + sim = test_convergence(dts, + prob, + Rodas3( + autodiff = AutoEnzyme( + mode = set_runtime_activity(Enzyme.Forward), function_annotation = Enzyme.Const), + linsolve = LinearSolve.KrylovJL())) + @test sim.𝒪est[:final]≈3 atol=testTol + + sol = solve(prob, + Rodas3( + autodiff = AutoEnzyme( + mode = set_runtime_activity(Enzyme.Forward), function_annotation = Enzyme.Const), + linsolve = LinearSolve.KrylovJL())) + @test length(sol) < 20 + ### ROS2 prob = prob_ode_linear @@ -475,6 +534,21 @@ import LinearSolve sol = solve(prob, Rodas23W()) @test length(sol) < 20 + sim = test_convergence(dts, + prob, + Rodas23W( + autodiff = AutoEnzyme( + mode = set_runtime_activity(Enzyme.Forward), function_annotation = Enzyme.Const), + linsolve = LinearSolve.KrylovJL())) + @test sim.𝒪est[:final] ≈ 2 atol = testTol + + sol = solve(prob, + Rodas23W( + autodiff = AutoEnzyme( + mode = set_runtime_activity(Enzyme.Forward), function_annotation = Enzyme.Const), + linsolve = LinearSolve.KrylovJL())) + @test length(sol) < 20 + println("Rodas3P") prob = prob_ode_linear @@ -495,6 +569,21 @@ import LinearSolve sol = solve(prob, Rodas3P()) @test length(sol) < 20 + sim = test_convergence(dts, + prob, + Rodas3P( + autodiff = AutoEnzyme( + mode = set_runtime_activity(Enzyme.Forward), function_annotation = Enzyme.Const), + linsolve = LinearSolve.KrylovJL())) + @test sim.𝒪est[:final]≈3 atol=testTol + + sol = solve(prob, + Rodas3P( + autodiff = AutoEnzyme( + mode = set_runtime_activity(Enzyme.Forward), function_annotation = Enzyme.Const), + linsolve = LinearSolve.KrylovJL())) + @test length(sol) < 20 + ### Rodas4 Algorithms println("RODAS") @@ -518,6 +607,11 @@ import LinearSolve sol = solve(prob, Rodas4(autodiff = AutoFiniteDiff())) @test length(sol) < 20 + sol = solve(prob, + Rodas4(autodiff = AutoEnzyme( + mode = set_runtime_activity(Enzyme.Forward), function_annotation = Enzyme.Const))) + @test length(sol) < 20 + sim = test_convergence(dts, prob, Rodas42(), dense_errors = true) @test sim.𝒪est[:final]≈5.1 atol=testTol @test sim.𝒪est[:L2]≈4 atol=testTol @@ -669,28 +763,55 @@ import LinearSolve sol = solve(prob, Rodas5Pe()) @test length(sol) < 20 - println("Rodas5Pr") + println("Rodas5P Enzyme Forward") prob = prob_ode_linear - sim = test_convergence(dts, prob, Rodas5Pr(), dense_errors = true) + sim = test_convergence(dts, prob, + Rodas5P(autodiff = AutoEnzyme(mode = set_runtime_activity(Enzyme.Forward), function_annotation = Enzyme.Const)), + dense_errors = true) #@test sim.𝒪est[:final]≈5 atol=testTol #-- observed order > 6 @test sim.𝒪est[:L2]≈5 atol=testTol - sol = solve(prob, Rodas5Pr()) + sol = solve(prob, + Rodas5P(autodiff = AutoEnzyme(mode = set_runtime_activity(Enzyme.Forward), function_annotation = Enzyme.Const))) @test length(sol) < 20 prob = prob_ode_2Dlinear - sim = test_convergence(dts, prob, Rodas5Pr(), dense_errors = true) + sim = test_convergence(dts, prob, + Rodas5P(autodiff = AutoEnzyme(mode = set_runtime_activity(Enzyme.Forward), function_annotation = Enzyme.Const)), + dense_errors = true) #@test sim.𝒪est[:final]≈5 atol=testTol #-- observed order > 6 @test sim.𝒪est[:L2]≈5 atol=testTol - sol = solve(prob, Rodas5Pr()) + sim = test_convergence(dts, prob, + Rodas5P(autodiff = AutoEnzyme( + mode = set_runtime_activity(Enzyme.Forward), function_annotation = Enzyme.Const), + linsolve = LinearSolve.KrylovJL()), + dense_errors = true) + #@test sim.𝒪est[:final]≈5 atol=testTol #-- observed order > 6 + @test sim.𝒪est[:L2]≈5 atol=testTol + + sim = test_convergence(dts, prob, + Rodas5P(autodiff = AutoEnzyme( + mode = set_runtime_activity(Enzyme.Forward), function_annotation = Enzyme.Const), + linsolve = LinearSolve.KrylovJL_GMRES()), + dense_errors = true) + #@test sim.𝒪est[:final]≈5 atol=testTol #-- observed order > 6 + @test sim.𝒪est[:L2]≈5 atol=testTol + + sol = solve(prob, + Rodas5P(autodiff = AutoEnzyme(mode = set_runtime_activity(Enzyme.Forward), + function_annotation = Enzyme.Const))) @test length(sol) < 20 + prob = ODEProblem((u, p, t) -> 0.9u, 0.1, (0.0, 1.0)) @test_nowarn solve(prob, Rosenbrock23(autodiff = AutoFiniteDiff())) + @test_nowarn solve(prob, + Rosenbrock23(autodiff = AutoEnzyme(mode = set_runtime_activity(Enzyme.Forward), + function_annotation = Enzyme.Const))) end @testset "Convergence with time-dependent matrix-free Jacobian" begin @@ -789,13 +910,11 @@ end else @inferred(solve(prob, alg; dt = 0.1)) end - @test sol.alg === alg alg = T(; autodiff = AutoFiniteDiff(; fdtype = Val(:central))) sol = if alg isa OrdinaryDiffEqRosenbrock.OrdinaryDiffEqRosenbrockAdaptiveAlgorithm @inferred(solve(prob, alg)) else @inferred(solve(prob, alg; dt = 0.1)) end - @test sol.alg === alg end end diff --git a/test/integrators/ode_cache_tests.jl b/test/integrators/ode_cache_tests.jl index 1dc92cbfc0..1194c54256 100644 --- a/test/integrators/ode_cache_tests.jl +++ b/test/integrators/ode_cache_tests.jl @@ -1,5 +1,5 @@ using OrdinaryDiffEq, OrdinaryDiffEqCore, DiffEqBase, Test, ADTypes -using Random, SparseDiffTools +using Random using OrdinaryDiffEqDefault using ElasticArrays, LinearSolve Random.seed!(213) diff --git a/test/integrators/resize_tests.jl b/test/integrators/resize_tests.jl index 1b2933d1a9..94fe53170a 100644 --- a/test/integrators/resize_tests.jl +++ b/test/integrators/resize_tests.jl @@ -1,6 +1,8 @@ -using OrdinaryDiffEq, Test, ADTypes +using OrdinaryDiffEq, Test, ADTypes, SparseMatrixColorings, DiffEqBase, ForwardDiff, SciMLBase, LinearSolve +import OrdinaryDiffEqDifferentiation.DI + f(du, u, p, t) = du .= u -prob = ODEProblem(f, [1.0], (0.0, 1.0)) +prob = ODEProblem{true, SciMLBase.FullSpecialize}(f, [1.0], (0.0, 1.0)) i = init(prob, Tsit5()) resize!(i, 5) @@ -31,11 +33,11 @@ resize!(i, 5) @test size(i.cache.nlsolver.cache.J) == (5, 5) @test size(i.cache.nlsolver.cache.W) == (5, 5) @test length(i.cache.nlsolver.cache.du1) == 5 -@test length(i.cache.nlsolver.cache.jac_config.fx) == 5 -@test length(i.cache.nlsolver.cache.jac_config.dx) == 5 -@test length(i.cache.nlsolver.cache.jac_config.t) == 5 -@test length(i.cache.nlsolver.cache.jac_config.p) == 5 @test length(i.cache.nlsolver.cache.weight) == 5 +@test all(size(DI.jacobian( + (du, u) -> (i.f(du, u, nothing, nothing)), rand(5), i.cache.nlsolver.cache.jac_config[1], + AutoForwardDiff(tag = ForwardDiff.Tag(DiffEqBase.OrdinaryDiffEqTag(), Float64)), rand(5))) .== + 5) solve!(i) i = init(prob, ImplicitEuler(; autodiff = AutoFiniteDiff())) @@ -54,10 +56,10 @@ resize!(i, 5) @test size(i.cache.nlsolver.cache.J) == (5, 5) @test size(i.cache.nlsolver.cache.W) == (5, 5) @test length(i.cache.nlsolver.cache.du1) == 5 -@test length(i.cache.nlsolver.cache.jac_config.x1) == 5 -@test length(i.cache.nlsolver.cache.jac_config.fx) == 5 -@test length(i.cache.nlsolver.cache.jac_config.fx1) == 5 @test length(i.cache.nlsolver.cache.weight) == 5 +@test all(size(DI.jacobian( + (du, u) -> (i.f(du, u, nothing, nothing)), rand(5), i.cache.nlsolver.cache.jac_config[1], + AutoFiniteDiff(), rand(5))) .== 5) solve!(i) i = init(prob, Rosenbrock23()) @@ -77,10 +79,29 @@ resize!(i, 5) @test size(i.cache.J) == (5, 5) @test size(i.cache.W) == (5, 5) @test length(i.cache.linsolve_tmp) == 5 -@test length(i.cache.jac_config.fx) == 5 -@test length(i.cache.jac_config.dx) == 5 -@test length(i.cache.jac_config.t) == 5 -@test length(i.cache.jac_config.p) == 5 +@test all(size(DI.jacobian( + (du, u) -> (i.f(du, u, nothing, nothing)), rand(5), i.cache.jac_config[1], + AutoForwardDiff(tag = ForwardDiff.Tag(DiffEqBase.OrdinaryDiffEqTag(), Float64)), rand(5))) .== + 5) +solve!(i) + +i = init(prob, Rosenbrock23(autodiff = AutoForwardDiff(), linsolve = KrylovJL_GMRES())) +resize!(i, 5) +@test length(i.cache.u) == 5 +@test length(i.cache.uprev) == 5 +@test length(i.cache.k₁) == 5 +@test length(i.cache.k₂) == 5 +@test length(i.cache.k₃) == 5 +@test length(i.cache.du1) == 5 +@test length(i.cache.du2) == 5 +@test length(i.cache.f₁) == 5 +@test length(i.cache.fsalfirst) == 5 +@test length(i.cache.fsallast) == 5 +@test length(i.cache.dT) == 5 +@test length(i.cache.tmp) == 5 +@test size(i.cache.J) == (5, 5) +@test size(i.cache.W) == (5, 5) +@test length(i.cache.linsolve_tmp) == 5 solve!(i) i = init(prob, Rosenbrock23(; autodiff = AutoFiniteDiff())) @@ -100,9 +121,9 @@ resize!(i, 5) @test size(i.cache.J) == (5, 5) @test size(i.cache.W) == (5, 5) @test length(i.cache.linsolve_tmp) == 5 -@test length(i.cache.jac_config.x1) == 5 -@test length(i.cache.jac_config.fx) == 5 -@test length(i.cache.jac_config.fx1) == 5 +@test all(size(DI.jacobian( + (du, u) -> (i.f(du, u, nothing, nothing)), rand(5), i.cache.jac_config[1], + AutoFiniteDiff(), rand(5))) .== 5) solve!(i) function f(du, u, p, t) diff --git a/test/interface/autodiff_error_tests.jl b/test/interface/autodiff_error_tests.jl index 5e4bdbd258..e3ffe24862 100644 --- a/test/interface/autodiff_error_tests.jl +++ b/test/interface/autodiff_error_tests.jl @@ -46,7 +46,7 @@ prob = ODEProblem(lorenz2!, u0, tspan) @test_throws OrdinaryDiffEqDifferentiation.FirstAutodiffTgradError solve( prob, Rosenbrock23()) -## Test that nothing is using duals when autodiff=false +## Test that nothing is using duals when autodiff=AutoFiniteDiff() ## https://discourse.julialang.org/t/rodas4-using-dual-number-for-time-with-autodiff-false/98256 for alg in [ diff --git a/test/interface/autosparse_detection_tests.jl b/test/interface/autosparse_detection_tests.jl new file mode 100644 index 0000000000..84a9d89dca --- /dev/null +++ b/test/interface/autosparse_detection_tests.jl @@ -0,0 +1,18 @@ +using Test, OrdinaryDiffEq, LinearSolve, ADTypes, ForwardDiff, SparseConnectivityTracer, + SparseMatrixColorings +import ODEProblemLibrary: prob_ode_2Dlinear + +ad = AutoSparse(AutoForwardDiff(), sparsity_detector = TracerSparsityDetector(), + coloring_algorithm = GreedyColoringAlgorithm()) + +prob = prob_ode_2Dlinear + +@test_nowarn solve(prob, Rodas5P(autodiff = ad)) + +@test_nowarn solve(prob, Rodas5P(autodiff = ad, linsolve = LinearSolve.KrylovJL_GMRES())) + +# Test that no dense matrices are made sparse +diag_prob = ODEProblem((du, u, p, t) -> du .= -1.0 .* u, rand(Int(1e7)), (0, 1.0)) + +@test_nowarn solve(diag_prob, Rodas5P(autodiff = ad, linsolve = LinearSolve.KrylovJL_GMRES())) + diff --git a/test/interface/nojac.jl b/test/interface/nojac.jl index 2888343e0f..50644be664 100644 --- a/test/interface/nojac.jl +++ b/test/interface/nojac.jl @@ -48,8 +48,8 @@ nojac = @allocated init(prob_ode_brusselator_2d, save_everystep = false) jac = @allocated init(prob_ode_brusselator_2d, TRBDF2(), save_everystep = false) @test jac / nojac > 50 -@test integ1.cache.nlsolver.cache.jac_config !== nothing -@test integ2.cache.nlsolver.cache.jac_config === nothing +@test integ1.cache.nlsolver.cache.jac_config !== (nothing, nothing) +@test integ2.cache.nlsolver.cache.jac_config === (nothing, nothing) ## Test that no Jac Config is created @@ -246,14 +246,14 @@ u0[17] = 0.007 prob = ODEProblem(ODEFunction(pollu, jac = fjac), u0, (0.0, 60.0)) integ = init(prob, Rosenbrock23(), abstol = 1e-6, reltol = 1e-6) -@test integ.cache.jac_config === nothing +@test integ.cache.jac_config === (nothing, nothing) integ = init(prob, Rosenbrock23(linsolve = SimpleLUFactorization()), abstol = 1e-6, reltol = 1e-6) -@test integ.cache.jac_config === nothing +@test integ.cache.jac_config === (nothing, nothing) integ = init(prob, Rosenbrock23(linsolve = GenericLUFactorization()), abstol = 1e-6, reltol = 1e-6) -@test integ.cache.jac_config === nothing +@test integ.cache.jac_config === (nothing, nothing) integ = init(prob, Rosenbrock23(linsolve = RFLUFactorization(), autodiff = AutoForwardDiff(chunksize = 3)), abstol = 1e-6, reltol = 1e-6) -@test integ.cache.jac_config === nothing +@test integ.cache.jac_config === (nothing, nothing) diff --git a/test/interface/sparsediff_tests.jl b/test/interface/sparsediff_tests.jl index 10304b5b57..ac97110a91 100644 --- a/test/interface/sparsediff_tests.jl +++ b/test/interface/sparsediff_tests.jl @@ -2,7 +2,9 @@ using Test using OrdinaryDiffEq using SparseArrays using LinearAlgebra +using LinearSolve using ADTypes +using Enzyme ## in-place #https://github.com/JuliaDiffEq/SparseDiffTools.jl/blob/master/test/test_integration.jl @@ -52,10 +54,11 @@ for f in [f_oop, f_ip] odefun_std = ODEFunction(f) prob_std = ODEProblem(odefun_std, u0, tspan) - for ad in [AutoForwardDiff(), AutoFiniteDiff()] - for Solver in [Rodas5, Rosenbrock23, Trapezoid, KenCarp4] + for ad in [AutoForwardDiff(), AutoFiniteDiff(), + AutoEnzyme(mode = Enzyme.Forward, function_annotation = Enzyme.Const)], linsolve in [nothing, LinearSolve.KrylovJL_GMRES()] + for Solver in [Rodas5, Rosenbrock23, Trapezoid, KenCarp4, FBDF] for tol in [nothing, 1e-10] - sol_std = solve(prob_std, Solver(autodiff = ad), reltol = tol, abstol = tol) + sol_std = solve(prob_std, Solver(autodiff = ad, linsolve = linsolve), reltol = tol, abstol = tol) @test sol_std.retcode == ReturnCode.Success for (i, prob) in enumerate(map(f -> ODEProblem(f, u0, tspan), [ @@ -65,7 +68,7 @@ for f in [f_oop, f_ip] ODEFunction(f, colorvec = colors, sparsity = jac_sp) ])) - sol = solve(prob, Solver(autodiff = ad), reltol = tol, abstol = tol) + sol = solve(prob, Solver(autodiff = ad, linsolve = linsolve), reltol = tol, abstol = tol) @test sol.retcode == ReturnCode.Success if tol != nothing @test sol_std.u[end]≈sol.u[end] atol=tol @@ -78,3 +81,4 @@ for f in [f_oop, f_ip] end end end + diff --git a/test/interface/stats_tests.jl b/test/interface/stats_tests.jl index 2c5bf18285..90b672d8bc 100644 --- a/test/interface/stats_tests.jl +++ b/test/interface/stats_tests.jl @@ -23,10 +23,10 @@ probip = ODEProblem(g, u0, tspan) @test x[] == sol.stats.nf end @testset "$alg" for alg in [Rodas5P, KenCarp4] - @testset "$kwargs" for kwargs in [(autodiff = AutoForwardDiff(),), - (autodiff = AutoFiniteDiff(fdtype = Val{:forward}()),), - (autodiff = AutoFiniteDiff(fdtype = Val{:central}()),), - (autodiff = AutoFiniteDiff(fdtype = Val{:complex}()),)] + @testset "$kwargs" for kwargs in [(autodiff = AutoForwardDiff(),)] + #(autodiff = AutoFiniteDiff(fdtype = Val{:forward}()),), + #(autodiff = AutoFiniteDiff(fdtype = Val{:central}()),), + #(autodiff = AutoFiniteDiff(fdtype = Val{:complex}()),)] x[] = 0 sol = solve(prob, alg(; kwargs...)) @test x[] == sol.stats.nf diff --git a/test/interface/stiffness_detection_test.jl b/test/interface/stiffness_detection_test.jl index 503f0bb7a9..82a4b95e80 100644 --- a/test/interface/stiffness_detection_test.jl +++ b/test/interface/stiffness_detection_test.jl @@ -27,7 +27,7 @@ end is_switching_fb(sol) = all(i -> count(isequal(i), sol.alg_choice[2:end]) > 5, (1, 2)) for (i, prob) in enumerate(probArr) println(i) - sol = @test_nowarn solve(prob, AutoTsit5(Rosenbrock23(autodiff = AutoFiniteDiff())), + sol = solve(prob, AutoTsit5(Rosenbrock23(autodiff = AutoFiniteDiff())), maxiters = 1000) @test is_switching_fb(sol) alg = AutoTsit5(Rodas5(); maxstiffstep = 5, maxnonstiffstep = 5, stiffalgfirst = true) diff --git a/test/interface/wprototype_tests.jl b/test/interface/wprototype_tests.jl index 29a564875e..0fada74956 100644 --- a/test/interface/wprototype_tests.jl +++ b/test/interface/wprototype_tests.jl @@ -49,6 +49,12 @@ for prob in (prob_ode_vanderpol_stiff,) sol_W = solve(prob_W, alg) rtol = 1e-2 + + @test prob_J.f.sparsity.A == prob_W.f.sparsity.A + + @test all(isapprox.(sol_J.t, sol_W.t; rtol)) + @test all(isapprox.(sol_J.u, sol_W.u; rtol)) + @test all(isapprox.(sol_J.t, sol.t; rtol)) @test all(isapprox.(sol_J.u, sol.u; rtol)) @test all(isapprox.(sol_W.t, sol.t; rtol)) diff --git a/test/regression/time_derivative_test.jl b/test/regression/time_derivative_test.jl index 99c9dc0fc9..57d363b6a8 100644 --- a/test/regression/time_derivative_test.jl +++ b/test/regression/time_derivative_test.jl @@ -1,4 +1,4 @@ -using OrdinaryDiffEq, StaticArrays, Test +using OrdinaryDiffEq, StaticArrays, Test, ADTypes, Enzyme function time_derivative(du, u, p, t) du[1] = -t @@ -29,26 +29,29 @@ for (ff_time_derivative, u0) in ( prob = ODEProblem(ff_time_derivative, u0, tspan) - for _autodiff in (true, false) + for _autodiff in (AutoForwardDiff(), AutoFiniteDiff(), + AutoEnzyme(mode = Enzyme.Forward, function_annotation = Enzyme.Const)) @info "autodiff=$(_autodiff)" + prec = !(_autodiff == AutoFiniteDiff()) + @show Rosenbrock23 sol = solve(prob, Rosenbrock32(autodiff = _autodiff), reltol = 1e-9, abstol = 1e-9) - @test sol.errors[:final] < 1e-5 * 10^(!_autodiff) + @test sol.errors[:final] < 1e-5 * 10^(!prec) sol = solve(prob, Rosenbrock23(autodiff = _autodiff), reltol = 1e-9, abstol = 1e-9) - @test sol.errors[:final] < 1e-10 * 10_000^(!_autodiff) + @test sol.errors[:final] < 1e-10 * 10_000^(!prec) @show Rodas4 sol = solve(prob, Rodas4(autodiff = _autodiff), reltol = 1e-9, abstol = 1e-9) - @test sol.errors[:final] < 1e-10 * 100_000^(!_autodiff) + @test sol.errors[:final] < 1e-10 * 100_000^(!prec) @show Rodas5 sol = solve(prob, Rodas5(autodiff = _autodiff), reltol = 1e-9, abstol = 1e-9) - @test sol.errors[:final] < 1e-10 * 1_000_000_000^(!_autodiff) + @test sol.errors[:final] < 1e-10 * 1_000_000_000^(!prec) @show Veldd4 sol = solve(prob, Veldd4(autodiff = _autodiff), reltol = 1e-9, abstol = 1e-9) - @test sol.errors[:final] < 1e-10 * 100_000^(!_autodiff) + @test sol.errors[:final] < 1e-10 * 100_000^(!prec) @show KenCarp3 sol = solve(prob, KenCarp3(autodiff = _autodiff), reltol = 1e-12, abstol = 1e-12) diff --git a/test/runtests.jl b/test/runtests.jl index ddd8fc54ab..af6d4c4511 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -65,6 +65,7 @@ end @time @safetestset "Linear Solver Tests" include("interface/linear_solver_test.jl") @time @safetestset "Linear Solver Split ODE Tests" include("interface/linear_solver_split_ode_test.jl") @time @safetestset "Sparse Diff Tests" include("interface/sparsediff_tests.jl") + @time @safetestset "AutoSparse Detection Tests" include("interface/autosparse_detection_tests.jl") @time @safetestset "Enum Tests" include("interface/enums.jl") @time @safetestset "CheckInit Tests" include("interface/checkinit_tests.jl") @time @safetestset "Get du Tests" include("interface/get_du.jl")