Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Nonlinear] add support for simplifying NonlinearFunction #2605

Draft
wants to merge 12 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/Nonlinear/Nonlinear.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,5 +48,6 @@ include("model.jl")
include("evaluator.jl")

include("ReverseAD/ReverseAD.jl")
include("SymbolicAD/SymbolicAD.jl")

end # module
225 changes: 225 additions & 0 deletions src/Nonlinear/SymbolicAD/SymbolicAD.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
# Copyright (c) 2017: Miles Lubin and contributors
# Copyright (c) 2017: Google Inc.
#
# Use of this source code is governed by an MIT-style license that can be found
# in the LICENSE.md file or at https://opensource.org/licenses/MIT.

module SymbolicAD
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this called SymbolicAD ? It's not really doing AD, it's just simplifying

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My plan was to move lanl-ansi/MathOptSymbolicAD.jl#39 into here

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense then


import MathOptInterface as MOI

"""
simplify(f)

Return a simplified version of the function `f`.

!!! warning
This function is not type stable by design.
"""
simplify(f) = f

function simplify(f::MOI.ScalarAffineFunction{T}) where {T}
f = MOI.Utilities.canonical(f)
if isempty(f.terms)
return f.constant
end
return f
end

function simplify(f::MOI.ScalarQuadraticFunction{T}) where {T}
f = MOI.Utilities.canonical(f)
if isempty(f.quadratic_terms)
return simplify(MOI.ScalarAffineFunction(f.affine_terms, f.constant))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Calling simplify will canonicalize again, we don't need to do the type unstable part of deciding whether it should be a constant.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand this comment

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we just did f = MOI.Utilities.canonical(f) so f.affine_terms are already canonicalized. Here, we create a SAF that we know is canonical and then pass it to simplify. The first thing the function will do is canonicalize again which is a bit wasteful

end
return f
end

function simplify(f::MOI.ScalarNonlinearFunction)
for i in 1:length(f.args)
f.args[i] = simplify(f.args[i])
end
return _eval_if_constant(simplify(Val(f.head), f))
end

function simplify(f::MOI.VectorAffineFunction{T}) where {T}
f = MOI.Utilities.canonical(f)
if isempty(f.terms)
return f.constants
end
return f
end

function simplify(f::MOI.VectorQuadraticFunction{T}) where {T}
f = MOI.Utilities.canonical(f)
if isempty(f.quadratic_terms)
return simplify(MOI.VectorAffineFunction(f.affine_terms, f.constants))
end
return f
end

function simplify(f::MOI.VectorNonlinearFunction)
return MOI.VectorNonlinearFunction(simplify.(f.rows))
end

# If a ScalarNonlinearFunction has only constant arguments, we should return
# the vaålue.

_isnum(::Any) = false

_isnum(::Union{Bool,Integer,Float64}) = true

function _eval_if_constant(f::MOI.ScalarNonlinearFunction)
if all(_isnum, f.args) && hasproperty(Base, f.head)
return getproperty(Base, f.head)(f.args...)
end
return f
end

_eval_if_constant(f) = f

_iszero(x::Any) = _isnum(x) && iszero(x)

_isone(x::Any) = _isnum(x) && isone(x)

"""
_isexpr(f::Any, head::Symbol[, n::Int])

Return `true` if `f` is a `ScalarNonlinearFunction` with head `head` and, if
specified, `n` arguments.
"""
_isexpr(::Any, ::Symbol, n::Int = 0) = false

_isexpr(f::MOI.ScalarNonlinearFunction, head::Symbol) = f.head == head

function _isexpr(f::MOI.ScalarNonlinearFunction, head::Symbol, n::Int)
return _isexpr(f, head) && length(f.args) == n
end

"""
simplify(::Val{head}, f::MOI.ScalarNonlinearFunction)

Return a simplified version of `f` where the head of `f` is `head`.

Implementing this method enables custom simplification rules for different
operators without needing a giant switch statement.
"""
simplify(::Val, f::MOI.ScalarNonlinearFunction) = f

function simplify(::Val{:*}, f::MOI.ScalarNonlinearFunction)
new_args = Any[]
first_constant = 0
for arg in f.args
if _isexpr(arg, :*)
# If the child is a :*, lift its arguments to the parent
append!(new_args, arg.args)
elseif _iszero(arg)
# If any argument is zero, the entire expression must be false
return false
elseif _isone(arg)
# Skip any arguments that are one
elseif arg isa Real
# Collect all constant arguments into a single value
if first_constant == 0
push!(new_args, arg)
first_constant = length(new_args)
else
new_args[first_constant] *= arg
end
else
push!(new_args, arg)
end
end
if isempty(new_args)
return true
elseif length(new_args) == 1
return only(new_args)
end
return MOI.ScalarNonlinearFunction(:*, new_args)
end

function simplify(::Val{:+}, f::MOI.ScalarNonlinearFunction)
if length(f.args) == 1
# +(x) -> x
return only(f.args)
elseif length(f.args) == 2 && _isexpr(f.args[2], :-, 1)
# +(x, -y) -> -(x, y)
return MOI.ScalarNonlinearFunction(
:-,
Any[f.args[1], f.args[2].args[1]],
)
end
new_args = Any[]
first_constant = 0
for arg in f.args
if _isexpr(arg, :+)
# If a child is a :+, lift its arguments to the parent
append!(new_args, arg.args)
elseif _iszero(arg)
# Skip any zero arguments
elseif arg isa Real
# Collect all constant arguments into a single value
if first_constant == 0
push!(new_args, arg)
first_constant = length(new_args)
else
new_args[first_constant] += arg
end
else
push!(new_args, arg)
end
end
if isempty(new_args)
# +() -> false
return false
elseif length(new_args) == 1
# +(x) -> x
return only(new_args)
end
return MOI.ScalarNonlinearFunction(:+, new_args)
end

function simplify(::Val{:-}, f::MOI.ScalarNonlinearFunction)
if length(f.args) == 1
if _isexpr(f.args[1], :-, 1)
# -(-(x)) => x
return f.args[1].args[1]
end
elseif length(f.args) == 2
if _iszero(f.args[1])
# 0 - x => -x
return MOI.ScalarNonlinearFunction(:-, Any[f.args[2]])
elseif _iszero(f.args[2])
# x - 0 => x
return f.args[1]
elseif f.args[1] == f.args[2]
# x - x => 0
return false
elseif _isexpr(f.args[2], :-, 1)
# x - -(y) => x + y
return MOI.ScalarNonlinearFunction(
:+,
Any[f.args[1], f.args[2].args[1]],
)
end
end
return f
end

function simplify(::Val{:^}, f::MOI.ScalarNonlinearFunction)
if _iszero(f.args[2])
# x^0 => 1
return true
elseif _isone(f.args[2])
# x^1 => x
return f.args[1]
elseif _iszero(f.args[1])
# 0^x => 0
return false
elseif _isone(f.args[1])
# 1^x => 1
return true
end
return f
end

end # module
9 changes: 9 additions & 0 deletions src/Utilities/functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1071,6 +1071,10 @@ function canonicalize!(
return f
end

function canonical(f::MOI.ScalarNonlinearFunction)::MOI.ScalarNonlinearFunction
return MOI.Nonlinear.SymbolicAD.simplify(f)
end

function canonicalize!(f::MOI.ScalarNonlinearFunction)
for (i, arg) in enumerate(f.args)
if !is_canonical(arg)
Expand All @@ -1080,6 +1084,11 @@ function canonicalize!(f::MOI.ScalarNonlinearFunction)
return f
end

function canonical(f::MOI.VectorNonlinearFunction)
rows = MOI.Nonlinear.SymbolicAD.simplify.(f.rows)
return MOI.VectorNonlinearFunction(rows)
end

function canonicalize!(f::MOI.VectorNonlinearFunction)
for (i, fi) in enumerate(f.rows)
f.rows[i] = canonicalize!(fi)
Expand Down
71 changes: 51 additions & 20 deletions src/functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -926,11 +926,23 @@ function _is_approx(x::AbstractArray, y::AbstractArray; kwargs...)
all(z -> _is_approx(z[1], z[2]; kwargs...), zip(x, y))
end

function _is_univariate_plus(f)
if f.head == :+ && length(f.args) == 1
return only(f.args) isa ScalarNonlinearFunction
end
return false
end

function Base.isapprox(
f::ScalarNonlinearFunction,
g::ScalarNonlinearFunction;
kwargs...,
)
if _is_univariate_plus(f)
return isapprox(only(f.args), g.args; kwargs...)
elseif _is_univariate_plus(g)
return isapprox(f, only(g.args); kwargs...)
end
if f.head != g.head || length(f.args) != length(g.args)
return false
end
Expand Down Expand Up @@ -1127,22 +1139,40 @@ _order(x::VariableIndex, y::Real, z::VariableIndex) = (y, x, z)
_order(x::VariableIndex, y::VariableIndex, z::Real) = (z, x, y)
_order(x, y, z) = nothing

_order_quad(x, y) = nothing
_order_quad(x::VariableIndex, y::VariableIndex) = (x, y)

function Base.convert(
::Type{ScalarQuadraticTerm{T}},
f::ScalarNonlinearFunction,
) where {T}
if f.head != :* || length(f.args) != 3
if f.head != :*
throw(InexactError(:convert, ScalarQuadraticTerm, f))
elseif length(f.args) == 2
# Deal with *(x, y)
ret_2 = _order_quad(f.args[1], f.args[2])
if ret_2 === nothing
throw(InexactError(:convert, ScalarQuadraticTerm, f))
end
coef = one(T)
if ret_2[1] == ret_2[2]
coef *= 2
end
return ScalarQuadraticTerm(coef, ret_2[1], ret_2[2])
elseif length(f.args) == 3
# *(constant, x, y)
ret = _order(f.args[1], f.args[2], f.args[3])
if ret === nothing
throw(InexactError(:convert, ScalarQuadraticTerm, f))
end
coef = convert(T, ret[1])
if ret[2] == ret[3]
coef *= 2
end
return ScalarQuadraticTerm(coef, ret[2], ret[3])
else
return throw(InexactError(:convert, ScalarQuadraticTerm, f))
end
ret = _order(f.args[1], f.args[2], f.args[3])
if ret === nothing
throw(InexactError(:convert, ScalarQuadraticTerm, f))
end
coef = convert(T, ret[1])
if ret[2] == ret[3]
coef *= 2
end
return ScalarQuadraticTerm(coef, ret[2], ret[3])
end

function _add_to_function(
Expand All @@ -1157,7 +1187,11 @@ function _add_to_function(
arg::ScalarNonlinearFunction,
) where {T}
if arg.head == :* && length(arg.args) == 2
push!(f.affine_terms, convert(ScalarAffineTerm{T}, arg))
if _order_quad(arg.args[1], arg.args[2]) === nothing
push!(f.affine_terms, convert(ScalarAffineTerm{T}, arg))
else
push!(f.quadratic_terms, convert(ScalarQuadraticTerm{T}, arg))
end
elseif arg.head == :* && length(arg.args) == 3
push!(f.quadratic_terms, convert(ScalarQuadraticTerm{T}, arg))
else
Expand All @@ -1174,15 +1208,12 @@ function Base.convert(
f::ScalarNonlinearFunction,
) where {T}
if f.head == :*
if length(f.args) == 2
quad_terms = ScalarQuadraticTerm{T}[]
affine_terms = [convert(ScalarAffineTerm{T}, f)]
return ScalarQuadraticFunction{T}(quad_terms, affine_terms, zero(T))
elseif length(f.args) == 3
quad_terms = [convert(ScalarQuadraticTerm{T}, f)]
affine_terms = ScalarAffineTerm{T}[]
return ScalarQuadraticFunction{T}(quad_terms, affine_terms, zero(T))
end
g = ScalarQuadraticFunction{T}(
ScalarQuadraticTerm{T}[],
ScalarAffineTerm{T}[],
zero(T),
)
return _add_to_function(g, f)
elseif f.head == :^ && length(f.args) == 2 && f.args[2] == 2
return convert(
ScalarQuadraticFunction{T},
Expand Down
9 changes: 3 additions & 6 deletions test/Bridges/Constraint/NormInfinityBridge.jl
Original file line number Diff line number Diff line change
Expand Up @@ -647,9 +647,9 @@ function test_NormInfinity_VectorNonlinearFunction()
g = MOI.VectorNonlinearFunction([
MOI.ScalarNonlinearFunction(
:+,
Any[MOI.ScalarNonlinearFunction(:-, Any[v_sin]), u_p],
Any[MOI.ScalarNonlinearFunction(:-, Any[v_sin]), u],
),
MOI.ScalarNonlinearFunction(:+, Any[v_sin, u_p]),
MOI.ScalarNonlinearFunction(:+, Any[v_sin, u]),
])
@test ≈(MOI.get(inner, MOI.ConstraintFunction(), indices[1]), g)
h = MOI.VectorNonlinearFunction([
Expand Down Expand Up @@ -689,10 +689,7 @@ function test_NormOne_VectorNonlinearFunction()
u, v, w = inner_variables
v_sin = MOI.ScalarNonlinearFunction(:sin, Any[v])
g = MOI.VectorNonlinearFunction([
MOI.ScalarNonlinearFunction(
:-,
Any[MOI.ScalarNonlinearFunction(:+, Any[u]), 0.0+1.0*w],
),
MOI.ScalarNonlinearFunction(:-, Any[u, 0.0+1.0*w]),
MOI.ScalarNonlinearFunction(
:+,
Any[MOI.ScalarNonlinearFunction(:-, Any[v_sin]), w],
Expand Down
Loading
Loading