Skip to content

Commit 6dddd65

Browse files
pianpwkpytorchmergebot
authored andcommitted
[dynamic shapes] add sym_and, sym_or (pytorch#150456)
This has been pretty helpful for the size-oblivious rewrite. Wanted the variadic args version to avoid `sym_or(a, sym_or(b, sym_or(c, d)))` in favor of `sym_or(a, b, c, d)`. Happy to change this to ban the 1-arg version. This is better than plain and/or because the whole symbolic expression gets preserved, and if we guard on it or defer as a runtime assert, we preserve all branches. Pull Request resolved: pytorch#150456 Approved by: https://github.com/laithsakka
1 parent 785495e commit 6dddd65

File tree

4 files changed

+57
-11
lines changed

4 files changed

+57
-11
lines changed

docs/source/fx.experimental.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@ torch.fx.experimental.symbolic_shapes
4343
guard_or_true
4444
guard_or_false
4545
guard_size_oblivious
46+
sym_and
4647
sym_eq
48+
sym_or
4749
constrain_range
4850
constrain_unify
4951
canonicalize_bool_expr

test/export/test_export.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6990,6 +6990,29 @@ def forward(self, a, b, mode):
69906990
_ = exported.module()(torch.randn(4, 4), torch.randn(4), "floor")
69916991
self.assertTrue(torch.allclose(exported.module()(*inps), foo(*inps)))
69926992

6993+
def test_sym_or_sym_and(self):
6994+
from torch.fx.experimental.symbolic_shapes import sym_and, sym_or
6995+
6996+
class Foo(torch.nn.Module):
6997+
def forward(self, xs):
6998+
u0, u1, u2 = xs.tolist()
6999+
torch._check(sym_or(u0 == 2, u0 == 4, u0 == 6))
7000+
torch._check(sym_and(u1 >= 4, u1 <= 8, u2 == 5))
7001+
return u0 + u1 + u2
7002+
7003+
ep = export(Foo(), (torch.tensor([2, 6, 5]),), strict=False)
7004+
ep.module()(torch.tensor([2, 6, 5]))
7005+
ep.module()(torch.tensor([4, 7, 5]))
7006+
ep.module()(torch.tensor([6, 5, 5]))
7007+
with self.assertRaisesRegex(
7008+
RuntimeError, r".* expression Eq\(u0, 2\) \| Eq\(u0, 4\) \| Eq\(u0, 6\) .*"
7009+
):
7010+
ep.module()(torch.tensor([3, 6, 5]))
7011+
with self.assertRaisesRegex(
7012+
RuntimeError, r".* expression Eq\(u2, 5\) & \(4 <= u1\) & \(u1 <= 8\) .*"
7013+
):
7014+
ep.module()(torch.tensor([6, 6, 6]))
7015+
69937016
def test_redundant_assert_max_upper_bound(self):
69947017
class M(torch.nn.Module):
69957018
def forward(self, x):

torch/_export/serde/serialize.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,8 @@ def _reverse_map(d: dict[Any, Enum]):
183183
operator.gt,
184184
operator.neg,
185185
operator.pos,
186+
operator.and_,
187+
operator.or_,
186188
math.trunc,
187189
torch.sym_not,
188190
operator.mul,

torch/fx/experimental/symbolic_shapes.py

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,9 @@ class PendingUnbackedSymbolNotFound(RuntimeError):
143143
"CURRENT_NODE_KEY",
144144
"has_free_symbols",
145145
"has_free_unbacked_symbols",
146+
"sym_and",
146147
"sym_eq",
148+
"sym_or",
147149
"SymbolicContext",
148150
"StatelessSymbolicContext",
149151
"StatefulSymbolicContext",
@@ -1299,17 +1301,19 @@ def statically_known_true(x: Union[bool, SymBool]) -> bool:
12991301
return result
13001302

13011303

1302-
# When a or b is evaluated, a is evaluated eagerly first then b. This causes
1303-
# a data dependent error for an expression “if u0==1 or True”. or over guarding for
1304-
# “if s0==1 or True”.
1305-
1306-
# On the other hand, when we use operator.or_, then dynamo will generate
1307-
# a sympy expression Sympy.Or(u0==1, True) without evaluating the args first.
1308-
1309-
# When the whole expression is passed to evaluation in that case, we do not throw a
1310-
# data dependent error or guard because we can statically know the result is True
1311-
# before unpacking the symbols.
1312-
sym_or = operator.or_
1304+
def sym_and(
1305+
x: Union[bool, SymBool], *others: Union[bool, SymBool]
1306+
) -> Union[bool, SymBool]:
1307+
"""
1308+
and, but for symbolic expressions, without bool casting.
1309+
"""
1310+
assert isinstance(x, (bool, SymBool))
1311+
if len(others) == 0:
1312+
return x
1313+
for y in others:
1314+
assert isinstance(y, (bool, SymBool))
1315+
x = operator.and_(x, y)
1316+
return x
13131317

13141318

13151319
def sym_eq(x: _T, y: _T) -> Union[bool, SymBool]:
@@ -1329,6 +1333,21 @@ def sym_eq(x: _T, y: _T) -> Union[bool, SymBool]:
13291333
raise AssertionError(f"unexpected sym_eq between {type(x)} {type(y)}")
13301334

13311335

1336+
def sym_or(
1337+
x: Union[bool, SymBool], *others: Union[bool, SymBool]
1338+
) -> Union[bool, SymBool]:
1339+
"""
1340+
or, but for symbolic expressions, without bool casting.
1341+
"""
1342+
assert isinstance(x, (bool, SymBool))
1343+
if len(others) == 0:
1344+
return x
1345+
for y in others:
1346+
assert isinstance(y, (bool, SymBool))
1347+
x = operator.or_(x, y)
1348+
return x
1349+
1350+
13321351
def guard_scalar(
13331352
a: Union[SymBool, SymInt, SymFloat, int, bool, float]
13341353
) -> Union[bool, int, float]:

0 commit comments

Comments
 (0)