Skip to content

Commit 40b1f7c

Browse files
Merge pull request #3090 from AayushSabharwal/as/promote-resid-prototype
fix: promote `resid_prototype` using tunables
2 parents 23b7b2e + 794a421 commit 40b1f7c

File tree

2 files changed

+47
-7
lines changed

2 files changed

+47
-7
lines changed

src/systems/nonlinear/nonlinearsystem.jl

+28-7
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ function SciMLBase.NonlinearFunction(sys::NonlinearSystem, args...; kwargs...)
295295
end
296296

297297
function SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = unknowns(sys),
298-
ps = parameters(sys), u0 = nothing;
298+
ps = parameters(sys), u0 = nothing, p = nothing;
299299
version = nothing,
300300
jac = false,
301301
eval_expression = false,
@@ -327,11 +327,22 @@ function SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = unknowns(s
327327

328328
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
329329

330+
if length(dvs) == length(equations(sys))
331+
resid_prototype = nothing
332+
else
333+
u0ElType = u0 === nothing ? Float64 : eltype(u0)
334+
if SciMLStructures.isscimlstructure(p)
335+
u0ElType = promote_type(
336+
eltype(SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)[1]),
337+
u0ElType)
338+
end
339+
resid_prototype = zeros(u0ElType, length(equations(sys)))
340+
end
341+
330342
NonlinearFunction{iip}(f,
331343
sys = sys,
332344
jac = _jac === nothing ? nothing : _jac,
333-
resid_prototype = length(dvs) == length(equations(sys)) ? nothing :
334-
zeros(length(equations(sys))),
345+
resid_prototype = resid_prototype,
335346
jac_prototype = sparse ?
336347
similar(calculate_jacobian(sys, sparse = sparse),
337348
Float64) : nothing,
@@ -355,7 +366,7 @@ variable and parameter vectors, respectively.
355366
struct NonlinearFunctionExpr{iip} end
356367

357368
function NonlinearFunctionExpr{iip}(sys::NonlinearSystem, dvs = unknowns(sys),
358-
ps = parameters(sys), u0 = nothing;
369+
ps = parameters(sys), u0 = nothing, p = nothing;
359370
version = nothing, tgrad = false,
360371
jac = false,
361372
linenumbers = false,
@@ -376,8 +387,18 @@ function NonlinearFunctionExpr{iip}(sys::NonlinearSystem, dvs = unknowns(sys),
376387
end
377388

378389
jp_expr = sparse ? :(similar($(get_jac(sys)[]), Float64)) : :nothing
379-
resid_expr = length(dvs) == length(equations(sys)) ? :nothing :
380-
:(zeros($(length(equations(sys)))))
390+
if length(dvs) == length(equations(sys))
391+
resid_expr = :nothing
392+
else
393+
u0ElType = u0 === nothing ? Float64 : eltype(u0)
394+
if SciMLStructures.isscimlstructure(p)
395+
u0ElType = promote_type(
396+
eltype(SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)[1]),
397+
u0ElType)
398+
end
399+
400+
resid_expr = :(zeros($u0ElType, $(length(equations(sys)))))
401+
end
381402
ex = quote
382403
f = $f
383404
jac = $_jac
@@ -412,7 +433,7 @@ function process_NonlinearProblem(constructor, sys::NonlinearSystem, u0map, para
412433
check_eqs_u0(eqs, dvs, u0; kwargs...)
413434
end
414435

415-
f = constructor(sys, dvs, ps, u0; jac = jac, checkbounds = checkbounds,
436+
f = constructor(sys, dvs, ps, u0, p; jac = jac, checkbounds = checkbounds,
416437
linenumbers = linenumbers, parallel = parallel, simplify = simplify,
417438
sparse = sparse, eval_expression = eval_expression, eval_module = eval_module,
418439
kwargs...)

test/nonlinearsystem.jl

+19
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ using ModelingToolkit: get_metadata
33
using DiffEqBase, SparseArrays
44
using Test
55
using NonlinearSolve
6+
using ForwardDiff
67
using ModelingToolkit: value
78
using ModelingToolkit: get_default_or_guess, MTKParameters
89

@@ -325,3 +326,21 @@ end
325326
prob = @test_nowarn NonlinearProblem(sys, nothing)
326327
@test_nowarn solve(prob)
327328
end
329+
330+
@testset "resid_prototype when system has no unknowns and an equation" begin
331+
@variables x
332+
@parameters p
333+
@named sys = NonlinearSystem([x ~ 1, x^2 - p ~ 0])
334+
for sys in [
335+
structural_simplify(sys, fully_determined = false),
336+
structural_simplify(sys, fully_determined = false, split = false)
337+
]
338+
@test length(equations(sys)) == 1
339+
@test length(unknowns(sys)) == 0
340+
T = typeof(ForwardDiff.Dual(1.0))
341+
prob = NonlinearProblem(sys, [], [p => ForwardDiff.Dual(1.0)]; check_length = false)
342+
@test prob.f(Float64[], prob.p) isa Vector{T}
343+
@test prob.f.resid_prototype isa Vector{T}
344+
@test_nowarn solve(prob)
345+
end
346+
end

0 commit comments

Comments
 (0)