Skip to content

Commit f7cb759

Browse files
Remove NLsolve dependency in favor of SimpleNonlinearSolve
Replace the direct NLsolve.jl dependency with SimpleNonlinearSolve.jl for the IIF methods (IIF1M, IIF2M, IIF1Mil). This follows the same approach taken in OrdinaryDiffEq.jl (PR #2081). Changes: - Remove NLsolve from dependencies and compat - Add SimpleNonlinearSolve to dependencies and compat - Update NLSOLVEJL_SETUP to use SimpleTrustRegion instead of NLsolve.nlsolve - Replace OnceDifferentiable wrapper with a simple function wrapper The implicit methods that use OrdinaryDiffEqNonlinearSolve (ImplicitEM, ISSEM, SKenCarp, etc.) are unaffected as they already use NonlinearSolve internally. Closes #642 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]>
1 parent 8a7d6e1 commit f7cb759

File tree

3 files changed

+24
-7
lines changed

3 files changed

+24
-7
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ LevyArea = "2d8b4e74-eb68-11e8-0fb9-d5eb67b50637"
1919
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
2020
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
2121
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
22-
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
2322
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
2423
OrdinaryDiffEqDifferentiation = "4302a76b-040a-498a-8c04-15b101fed76b"
2524
OrdinaryDiffEqNonlinearSolve = "127b3ac7-2247-4354-8eb6-78cf4e7c58e8"
@@ -28,6 +27,7 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
2827
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
2928
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
3029
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
30+
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
3131
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
3232
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
3333
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
@@ -49,7 +49,6 @@ LinearAlgebra = "1.6"
4949
Logging = "1.6"
5050
ModelingToolkit = "10"
5151
MuladdMacro = "0.2.1"
52-
NLsolve = "4"
5352
OrdinaryDiffEq = "6.87"
5453
OrdinaryDiffEqCore = "1.32.0"
5554
OrdinaryDiffEqDifferentiation = "1.9"
@@ -59,6 +58,7 @@ RecursiveArrayTools = "2, 3"
5958
Reexport = "0.2, 1.0"
6059
SciMLBase = "2.115"
6160
SciMLOperators = "0.2.9, 0.3, 0.4, 1"
61+
SimpleNonlinearSolve = "2"
6262
SparseArrays = "1.6"
6363
StaticArrays = "0.11, 0.12, 1.0"
6464
UnPack = "0.1, 1.0"

src/StochasticDiffEq.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import OrdinaryDiffEqCore: default_controller, isstandard, ispredictive,
2121

2222
using UnPack, RecursiveArrayTools, DataStructures
2323
using DiffEqNoiseProcess, Random, ArrayInterface
24-
using NLsolve, ForwardDiff, StaticArrays, MuladdMacro, FiniteDiff, Base.Threads
24+
using SimpleNonlinearSolve, ForwardDiff, StaticArrays, MuladdMacro, FiniteDiff, Base.Threads
2525
using Adapt
2626

2727
import DiffEqBase: ODE_DEFAULT_NORM, ODE_DEFAULT_ISOUTOFDOMAIN,

src/misc_utils.jl

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,28 @@ end
3232
struct NLSOLVEJL_SETUP{CS, AD} end
3333
Base.@pure NLSOLVEJL_SETUP(; chunk_size = 0, autodiff = true) = NLSOLVEJL_SETUP{
3434
chunk_size, autodiff}()
35-
(::NLSOLVEJL_SETUP)(f, u0; kwargs...) = (res = NLsolve.nlsolve(f, u0; kwargs...); res.zero)
35+
36+
# Wrapper to store the function for use with SimpleNonlinearSolve
37+
struct IIFNLSolveFunc{F}
38+
f::F
39+
end
40+
41+
function (p::NLSOLVEJL_SETUP{CS, AD})(f_wrapper::IIFNLSolveFunc, u0; kwargs...) where {
42+
CS, AD}
43+
f = f_wrapper.f
44+
# Create a NonlinearProblem-compatible function
45+
# The IIF methods use f(resid, u) signature (in-place)
46+
nlf = NonlinearFunction{true}((resid, u, p) -> (f(resid, u); nothing))
47+
prob = NonlinearProblem(nlf, u0)
48+
ad = AD ? AutoForwardDiff() : AutoFiniteDiff()
49+
alg = SimpleTrustRegion(; autodiff = ad)
50+
sol = solve(prob, alg)
51+
return sol.u
52+
end
53+
3654
function (p::NLSOLVEJL_SETUP{CS, AD})(::Type{Val{:init}}, f, u0_prototype) where {CS, AD}
37-
AD ? autodiff = :forward : autodiff = :central
38-
OnceDifferentiable(f, u0_prototype, u0_prototype, autodiff,
39-
ForwardDiff.Chunk(determine_chunksize(u0_prototype, CS)))
55+
# Return a wrapper that stores the function
56+
IIFNLSolveFunc(f)
4057
end
4158

4259
get_chunksize(x) = 0

0 commit comments

Comments
 (0)