Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ LevyArea = "2d8b4e74-eb68-11e8-0fb9-d5eb67b50637"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
OrdinaryDiffEqDifferentiation = "4302a76b-040a-498a-8c04-15b101fed76b"
OrdinaryDiffEqNonlinearSolve = "127b3ac7-2247-4354-8eb6-78cf4e7c58e8"
Expand All @@ -28,6 +27,7 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
Expand All @@ -49,7 +49,6 @@ LinearAlgebra = "1.6"
Logging = "1.6"
ModelingToolkit = "10"
MuladdMacro = "0.2.1"
NLsolve = "4"
OrdinaryDiffEq = "6.87"
OrdinaryDiffEqCore = "1.32.0"
OrdinaryDiffEqDifferentiation = "1.9"
Expand All @@ -59,6 +58,7 @@ RecursiveArrayTools = "2, 3"
Reexport = "0.2, 1.0"
SciMLBase = "2.115"
SciMLOperators = "0.2.9, 0.3, 0.4, 1"
SimpleNonlinearSolve = "2"
SparseArrays = "1.6"
StaticArrays = "0.11, 0.12, 1.0"
UnPack = "0.1, 1.0"
Expand Down
2 changes: 1 addition & 1 deletion src/StochasticDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import OrdinaryDiffEqCore: default_controller, isstandard, ispredictive,

using UnPack, RecursiveArrayTools, DataStructures
using DiffEqNoiseProcess, Random, ArrayInterface
using NLsolve, ForwardDiff, StaticArrays, MuladdMacro, FiniteDiff, Base.Threads
using SimpleNonlinearSolve, ForwardDiff, StaticArrays, MuladdMacro, FiniteDiff, Base.Threads
using Adapt

import DiffEqBase: ODE_DEFAULT_NORM, ODE_DEFAULT_ISOUTOFDOMAIN,
Expand Down
25 changes: 21 additions & 4 deletions src/misc_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,28 @@ end
struct NLSOLVEJL_SETUP{CS, AD} end
Base.@pure NLSOLVEJL_SETUP(; chunk_size = 0, autodiff = true) = NLSOLVEJL_SETUP{
chunk_size, autodiff}()
(::NLSOLVEJL_SETUP)(f, u0; kwargs...) = (res = NLsolve.nlsolve(f, u0; kwargs...); res.zero)

# Wrapper to store the function for use with SimpleNonlinearSolve
struct IIFNLSolveFunc{F}
f::F
end

function (p::NLSOLVEJL_SETUP{CS, AD})(f_wrapper::IIFNLSolveFunc, u0; kwargs...) where {
CS, AD}
f = f_wrapper.f
# Create a NonlinearProblem-compatible function
# The IIF methods use f(resid, u) signature (in-place)
nlf = NonlinearFunction{true}((resid, u, p) -> (f(resid, u); nothing))
prob = NonlinearProblem(nlf, u0)
ad = AD ? AutoForwardDiff() : AutoFiniteDiff()
alg = SimpleTrustRegion(; autodiff = ad)
sol = solve(prob, alg)
return sol.u
end

function (p::NLSOLVEJL_SETUP{CS, AD})(::Type{Val{:init}}, f, u0_prototype) where {CS, AD}
AD ? autodiff = :forward : autodiff = :central
OnceDifferentiable(f, u0_prototype, u0_prototype, autodiff,
ForwardDiff.Chunk(determine_chunksize(u0_prototype, CS)))
# Return a wrapper that stores the function
IIFNLSolveFunc(f)
end

get_chunksize(x) = 0
Expand Down
Loading