Skip to content

Commit 6b36613

Browse files
authored
Merge pull request #3099 from jsiirola/immutable-expr-args
Enforce expression immutability in `expr.args`
2 parents 4d1a4ed + 0d00410 commit 6b36613

4 files changed

Lines changed: 71 additions & 87 deletions

File tree

pyomo/core/expr/numeric_expr.py

Lines changed: 48 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -1162,13 +1162,30 @@ def nargs(self):
11621162

11631163
@property
11641164
def args(self):
1165-
if len(self._args_) != self._nargs:
1166-
self._args_ = self._args_[: self._nargs]
1167-
return self._args_
1165+
# We unconditionally make a copy of the args to isolate the user
1166+
# from future possible updates to the underlying list
1167+
return self._args_[: self._nargs]
11681168

11691169
def getname(self, *args, **kwds):
11701170
return 'sum'
11711171

1172+
def _trunc_append(self, other):
1173+
_args = self._args_
1174+
if len(_args) > self._nargs:
1175+
_args = _args[: self._nargs]
1176+
_args.append(other)
1177+
return self.__class__(_args)
1178+
1179+
def _trunc_extend(self, other):
1180+
_args = self._args_
1181+
if len(_args) > self._nargs:
1182+
_args = _args[: self._nargs]
1183+
if len(other._args_) == other._nargs:
1184+
_args.extend(other._args_)
1185+
else:
1186+
_args.extend(other._args_[: other._nargs])
1187+
return self.__class__(_args)
1188+
11721189
def _apply_operation(self, result):
11731190
return sum(result)
11741191

@@ -1821,17 +1838,13 @@ def _add_native_monomial(a, b):
18211838
def _add_native_linear(a, b):
18221839
if not a:
18231840
return b
1824-
args = b.args
1825-
args.append(a)
1826-
return b.__class__(args)
1841+
return b._trunc_append(a)
18271842

18281843

18291844
def _add_native_sum(a, b):
18301845
if not a:
18311846
return b
1832-
args = b.args
1833-
args.append(a)
1834-
return b.__class__(args)
1847+
return b._trunc_append(a)
18351848

18361849

18371850
def _add_native_other(a, b):
@@ -1872,15 +1885,11 @@ def _add_npv_monomial(a, b):
18721885

18731886

18741887
def _add_npv_linear(a, b):
1875-
args = b.args
1876-
args.append(a)
1877-
return b.__class__(args)
1888+
return b._trunc_append(a)
18781889

18791890

18801891
def _add_npv_sum(a, b):
1881-
args = b.args
1882-
args.append(a)
1883-
return b.__class__(args)
1892+
return b._trunc_append(a)
18841893

18851894

18861895
def _add_npv_other(a, b):
@@ -1942,19 +1951,15 @@ def _add_param_linear(a, b):
19421951
a = a.value
19431952
if not a:
19441953
return b
1945-
args = b.args
1946-
args.append(a)
1947-
return b.__class__(args)
1954+
return b._trunc_append(a)
19481955

19491956

19501957
def _add_param_sum(a, b):
19511958
if a.is_constant():
19521959
a = value(a)
19531960
if not a:
19541961
return b
1955-
args = b.args
1956-
args.append(a)
1957-
return b.__class__(args)
1962+
return b._trunc_append(a)
19581963

19591964

19601965
def _add_param_other(a, b):
@@ -1999,15 +2004,11 @@ def _add_var_monomial(a, b):
19992004

20002005

20012006
def _add_var_linear(a, b):
2002-
args = b.args
2003-
args.append(MonomialTermExpression((1, a)))
2004-
return b.__class__(args)
2007+
return b._trunc_append(MonomialTermExpression((1, a)))
20052008

20062009

20072010
def _add_var_sum(a, b):
2008-
args = b.args
2009-
args.append(a)
2010-
return b.__class__(args)
2011+
return b._trunc_append(a)
20112012

20122013

20132014
def _add_var_other(a, b):
@@ -2046,15 +2047,11 @@ def _add_monomial_monomial(a, b):
20462047

20472048

20482049
def _add_monomial_linear(a, b):
2049-
args = b.args
2050-
args.append(a)
2051-
return b.__class__(args)
2050+
return b._trunc_append(a)
20522051

20532052

20542053
def _add_monomial_sum(a, b):
2055-
args = b.args
2056-
args.append(a)
2057-
return b.__class__(args)
2054+
return b._trunc_append(a)
20582055

20592056

20602057
def _add_monomial_other(a, b):
@@ -2069,49 +2066,35 @@ def _add_monomial_other(a, b):
20692066
def _add_linear_native(a, b):
20702067
if not b:
20712068
return a
2072-
args = a.args
2073-
args.append(b)
2074-
return a.__class__(args)
2069+
return a._trunc_append(b)
20752070

20762071

20772072
def _add_linear_npv(a, b):
2078-
args = a.args
2079-
args.append(b)
2080-
return a.__class__(args)
2073+
return a._trunc_append(b)
20812074

20822075

20832076
def _add_linear_param(a, b):
20842077
if b.is_constant():
20852078
b = b.value
20862079
if not b:
20872080
return a
2088-
args = a.args
2089-
args.append(b)
2090-
return a.__class__(args)
2081+
return a._trunc_append(b)
20912082

20922083

20932084
def _add_linear_var(a, b):
2094-
args = a.args
2095-
args.append(MonomialTermExpression((1, b)))
2096-
return a.__class__(args)
2085+
return a._trunc_append(MonomialTermExpression((1, b)))
20972086

20982087

20992088
def _add_linear_monomial(a, b):
2100-
args = a.args
2101-
args.append(b)
2102-
return a.__class__(args)
2089+
return a._trunc_append(b)
21032090

21042091

21052092
def _add_linear_linear(a, b):
2106-
args = a.args
2107-
args.extend(b.args)
2108-
return a.__class__(args)
2093+
return a._trunc_extend(b)
21092094

21102095

21112096
def _add_linear_sum(a, b):
2112-
args = b.args
2113-
args.append(a)
2114-
return b.__class__(args)
2097+
return b._trunc_append(a)
21152098

21162099

21172100
def _add_linear_other(a, b):
@@ -2126,55 +2109,39 @@ def _add_linear_other(a, b):
21262109
def _add_sum_native(a, b):
21272110
if not b:
21282111
return a
2129-
args = a.args
2130-
args.append(b)
2131-
return a.__class__(args)
2112+
return a._trunc_append(b)
21322113

21332114

21342115
def _add_sum_npv(a, b):
2135-
args = a.args
2136-
args.append(b)
2137-
return a.__class__(args)
2116+
return a._trunc_append(b)
21382117

21392118

21402119
def _add_sum_param(a, b):
21412120
if b.is_constant():
21422121
b = b.value
21432122
if not b:
21442123
return a
2145-
args = a.args
2146-
args.append(b)
2147-
return a.__class__(args)
2124+
return a._trunc_append(b)
21482125

21492126

21502127
def _add_sum_var(a, b):
2151-
args = a.args
2152-
args.append(b)
2153-
return a.__class__(args)
2128+
return a._trunc_append(b)
21542129

21552130

21562131
def _add_sum_monomial(a, b):
2157-
args = a.args
2158-
args.append(b)
2159-
return a.__class__(args)
2132+
return a._trunc_append(b)
21602133

21612134

21622135
def _add_sum_linear(a, b):
2163-
args = a.args
2164-
args.append(b)
2165-
return a.__class__(args)
2136+
return a._trunc_append(b)
21662137

21672138

21682139
def _add_sum_sum(a, b):
2169-
args = a.args
2170-
args.extend(b.args)
2171-
return a.__class__(args)
2140+
return a._trunc_extend(b)
21722141

21732142

21742143
def _add_sum_other(a, b):
2175-
args = a.args
2176-
args.append(b)
2177-
return a.__class__(args)
2144+
return a._trunc_append(b)
21782145

21792146

21802147
#
@@ -2213,9 +2180,7 @@ def _add_other_linear(a, b):
22132180

22142181

22152182
def _add_other_sum(a, b):
2216-
args = b.args
2217-
args.append(a)
2218-
return b.__class__(args)
2183+
return b._trunc_append(a)
22192184

22202185

22212186
def _add_other_other(a, b):
@@ -2628,8 +2593,8 @@ def _neg_var(a):
26282593

26292594

26302595
def _neg_monomial(a):
2631-
args = a.args
2632-
return MonomialTermExpression((-args[0], args[1]))
2596+
coef, var = a.args
2597+
return MonomialTermExpression((-coef, var))
26332598

26342599

26352600
def _neg_sum(a):

pyomo/core/tests/unit/test_derivs.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,18 @@ def test_nested_named_expressions(self):
322322
self.assertAlmostEqual(derivs[m.y], pyo.value(symbolic[m.y]), tol + 3)
323323
self.assertAlmostEqual(derivs[m.y], approx_deriv(e, m.y), tol)
324324

325+
def test_linear_exprs_issue_3096(self):
326+
m = pyo.ConcreteModel()
327+
m.y1 = pyo.Var(initialize=10)
328+
m.y2 = pyo.Var(initialize=100)
329+
e = (m.y1 - 0.5) * (m.y1 - 0.5) + (m.y2 - 0.5) * (m.y2 - 0.5)
330+
derivs = reverse_ad(e)
331+
self.assertEqual(derivs[m.y1], 19)
332+
self.assertEqual(derivs[m.y2], 199)
333+
symbolic = reverse_sd(e)
334+
self.assertExpressionsEqual(symbolic[m.y1], m.y1 - 0.5 + m.y1 - 0.5)
335+
self.assertExpressionsEqual(symbolic[m.y2], m.y2 - 0.5 + m.y2 - 0.5)
336+
325337

326338
class TestDifferentiate(unittest.TestCase):
327339
@unittest.skipUnless(sympy_available, "test requires sympy")

pyomo/core/tests/unit/test_numeric_expr.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1618,9 +1618,9 @@ def test_nestedProduct2(self):
16181618
),
16191619
)
16201620
# Verify shared args...
1621-
self.assertIsNot(e1._args_, e2._args_)
1622-
self.assertIs(e1._args_, e3._args_)
1623-
self.assertIs(e1._args_, e.arg(1)._args_)
1621+
self.assertIs(e1._args_, e2._args_)
1622+
self.assertIsNot(e1._args_, e3._args_)
1623+
self.assertIs(e1._args_, e.arg(0)._args_)
16241624
self.assertIs(e.arg(0).arg(0), e.arg(1).arg(0))
16251625
self.assertIs(e.arg(0).arg(1), e.arg(1).arg(1))
16261626

pyomo/core/tests/unit/test_numeric_expr_api.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -950,7 +950,14 @@ def test_sum(self):
950950
f = e.create_node_with_local_data(e.args)
951951
self.assertIsNot(f, e)
952952
self.assertIs(type(f), type(e))
953-
self.assertIs(f.args, e.args)
953+
self.assertIsNot(f._args_, e._args_)
954+
self.assertIsNot(f.args, e.args)
955+
956+
f = e.create_node_with_local_data(e._args_)
957+
self.assertIsNot(f, e)
958+
self.assertIs(type(f), type(e))
959+
self.assertIs(f._args_, e._args_)
960+
self.assertIsNot(f.args, e.args)
954961

955962
f = e.create_node_with_local_data((m.x, 2, 3))
956963
self.assertIsNot(f, e)

0 commit comments

Comments
 (0)