Skip to content

Commit 95036fb

Browse files
authored
[Nonlinear] fix splatting with a univariate operator (#2221)
1 parent 0214039 commit 95036fb

File tree

2 files changed

+31
-2
lines changed

2 files changed

+31
-2
lines changed

src/Nonlinear/parse.jl

+12-2
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,9 @@ function _parse_expression(stack, data, expr, x, parent_index)
112112
if length(x.args) == 2 && !isexpr(x.args[2], :...)
113113
_parse_univariate_expression(stack, data, expr, x, parent_index)
114114
else
115+
# The call is either n-ary, or it is a splat, in which case we
116+
# cannot tell just yet whether the expression is unary or nary.
117+
# Punt to multivariate and try to recover later.
115118
_parse_multivariate_expression(stack, data, expr, x, parent_index)
116119
end
117120
elseif isexpr(x, :comparison)
@@ -177,8 +180,15 @@ function _parse_multivariate_expression(
177180
@assert isexpr(x, :call)
178181
id = get(data.operators.multivariate_operator_to_id, x.args[1], nothing)
179182
if id === nothing
180-
@assert x.args[1] in data.operators.comparison_operators
181-
_parse_inequality_expression(stack, data, expr, x, parent_index)
183+
if haskey(data.operators.univariate_operator_to_id, x.args[1])
184+
# It may also be a unary variate operator with splatting.
185+
_parse_univariate_expression(stack, data, expr, x, parent_index)
186+
elseif x.args[1] in data.operators.comparison_operators
187+
# Or it may be a binary (in)equality operator.
188+
_parse_inequality_expression(stack, data, expr, x, parent_index)
189+
else
190+
throw(MOI.UnsupportedNonlinearOperator(x.args[1]))
191+
end
182192
return
183193
end
184194
push!(expr.nodes, Node(NODE_CALL_MULTIVARIATE, id, parent_index))

test/Nonlinear/Nonlinear.jl

+19
Original file line numberDiff line numberDiff line change
@@ -1084,6 +1084,25 @@ function test_ListOfSupportedNonlinearOperators()
10841084
return
10851085
end
10861086

1087+
function test_parse_univariate_splatting()
1088+
model = MOI.Nonlinear.Model()
1089+
MOI.Nonlinear.register_operator(model, :f, 1, x -> 2x)
1090+
x = [MOI.VariableIndex(1)]
1091+
@test MOI.Nonlinear.parse_expression(model, :(f($x...))) ==
1092+
MOI.Nonlinear.parse_expression(model, :(f($(x[1]))))
1093+
return
1094+
end
1095+
1096+
function test_parse_unsupported_operator()
1097+
model = MOI.Nonlinear.Model()
1098+
x = [MOI.VariableIndex(1)]
1099+
@test_throws(
1100+
MOI.UnsupportedNonlinearOperator(:f),
1101+
MOI.Nonlinear.parse_expression(model, :(f($x...))),
1102+
)
1103+
return
1104+
end
1105+
10871106
end
10881107

10891108
TestNonlinear.runtests()

0 commit comments

Comments
 (0)