diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index 79a88f18d..01dabe7a2 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -1,7 +1,7 @@ module LinearSolveForwardDiffExt using LinearSolve -using LinearSolve: SciMLLinearSolveAlgorithm +using LinearSolve: SciMLLinearSolveAlgorithm, __init using LinearAlgebra using ForwardDiff using ForwardDiff: Dual, Partials @@ -121,8 +121,17 @@ function linearsolve_dual_solution(u::AbstractArray, partials, zip(u, partials_list[i, :] for i in 1:length(partials_list.u[1]))) end -function SciMLBase.init( - prob::DualAbstractLinearProblem, alg::LinearSolve.SciMLLinearSolveAlgorithm, +function SciMLBase.init(prob::DualAbstractLinearProblem, alg::SciMLLinearSolveAlgorithm, args...; kwargs...) + return __dual_init(prob, alg, args...; kwargs...) +end + +# Opt out for GenericLUFactorization +function SciMLBase.init(prob::DualAbstractLinearProblem, alg::GenericLUFactorization, args...; kwargs...) + return __init(prob,alg, args...; kwargs...) +end + +function __dual_init( + prob::DualAbstractLinearProblem, alg::SciMLLinearSolveAlgorithm, args...; alias = LinearAliasSpecifier(), abstol = LinearSolve.default_tol(real(eltype(prob.b))), diff --git a/src/common.jl b/src/common.jl index f447df8e7..a40064e51 100644 --- a/src/common.jl +++ b/src/common.jl @@ -137,7 +137,11 @@ function __init_u0_from_Ab(A, b) end __init_u0_from_Ab(::SMatrix{S1, S2}, b) where {S1, S2} = zeros(SVector{S2, eltype(b)}) -function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm, +function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm, args...; kwargs...) + __init(prob, alg, args...; kwargs...) +end + +function __init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm, args...; alias = LinearAliasSpecifier(), abstol = default_tol(real(eltype(prob.b))), diff --git a/test/forwarddiff_overloads.jl b/test/forwarddiff_overloads.jl index 3d4a035f3..b7710f9de 100644 --- a/test/forwarddiff_overloads.jl +++ b/test/forwarddiff_overloads.jl @@ -87,13 +87,6 @@ backslash_x_p = A \ new_b @test ≈(x_p, backslash_x_p, rtol = 1e-9) # Nested Duals -function h(p) - (A = [p[1] p[2]+1 p[2]^3; - 3*p[1] p[1]+5 p[2] * p[1]-4; - p[2]^2 9*p[1] p[2]], - b = [p[1] + 1, p[2] * 2, p[1]^2]) -end - A, b = h([ForwardDiff.Dual(ForwardDiff.Dual(5.0, 1.0, 0.0), 1.0, 0.0), ForwardDiff.Dual(ForwardDiff.Dual(5.0, 1.0, 0.0), 0.0, 1.0)]) @@ -193,3 +186,10 @@ overload_x_p = solve(prob, UMFPACKFactorization()) backslash_x_p = A \ b @test ≈(overload_x_p, backslash_x_p, rtol = 1e-9) + + +# Test that GenericLU doesn't create a DualLinearCache +A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) + +prob = LinearProblem(A, b) +@test init(prob, GenericLUFactorization()) isa LinearSolve.LinearCache