Skip to content

Commit 38642bf

Browse files
authored
gh-126835: Move constant unaryop & binop folding to CFG (#129550)
1 parent d88677a commit 38642bf

File tree

6 files changed

+1058
-444
lines changed

6 files changed

+1058
-444
lines changed

Lib/test/test_ast/test_ast.py

+80-118
Original file line numberDiff line numberDiff line change
@@ -154,18 +154,17 @@ def test_optimization_levels__debug__(self):
154154
self.assertEqual(res.body[0].value.id, expected)
155155

156156
def test_optimization_levels_const_folding(self):
157-
folded = ('Expr', (1, 0, 1, 5), ('Constant', (1, 0, 1, 5), 3, None))
158-
not_folded = ('Expr', (1, 0, 1, 5),
159-
('BinOp', (1, 0, 1, 5),
160-
('Constant', (1, 0, 1, 1), 1, None),
161-
('Add',),
162-
('Constant', (1, 4, 1, 5), 2, None)))
157+
folded = ('Expr', (1, 0, 1, 6), ('Constant', (1, 0, 1, 6), (1, 2), None))
158+
not_folded = ('Expr', (1, 0, 1, 6),
159+
('Tuple', (1, 0, 1, 6),
160+
[('Constant', (1, 1, 1, 2), 1, None),
161+
('Constant', (1, 4, 1, 5), 2, None)], ('Load',)))
163162

164163
cases = [(-1, not_folded), (0, not_folded), (1, folded), (2, folded)]
165164
for (optval, expected) in cases:
166165
with self.subTest(optval=optval):
167-
tree1 = ast.parse("1 + 2", optimize=optval)
168-
tree2 = ast.parse(ast.parse("1 + 2"), optimize=optval)
166+
tree1 = ast.parse("(1, 2)", optimize=optval)
167+
tree2 = ast.parse(ast.parse("(1, 2)"), optimize=optval)
169168
for tree in [tree1, tree2]:
170169
res = to_tuple(tree.body[0])
171170
self.assertEqual(res, expected)
@@ -3089,27 +3088,6 @@ def test_cli_file_input(self):
30893088

30903089

30913090
class ASTOptimiziationTests(unittest.TestCase):
3092-
binop = {
3093-
"+": ast.Add(),
3094-
"-": ast.Sub(),
3095-
"*": ast.Mult(),
3096-
"/": ast.Div(),
3097-
"%": ast.Mod(),
3098-
"<<": ast.LShift(),
3099-
">>": ast.RShift(),
3100-
"|": ast.BitOr(),
3101-
"^": ast.BitXor(),
3102-
"&": ast.BitAnd(),
3103-
"//": ast.FloorDiv(),
3104-
"**": ast.Pow(),
3105-
}
3106-
3107-
unaryop = {
3108-
"~": ast.Invert(),
3109-
"+": ast.UAdd(),
3110-
"-": ast.USub(),
3111-
}
3112-
31133091
def wrap_expr(self, expr):
31143092
return ast.Module(body=[ast.Expr(value=expr)])
31153093

@@ -3141,83 +3119,6 @@ def assert_ast(self, code, non_optimized_target, optimized_target):
31413119
f"{ast.dump(optimized_tree)}",
31423120
)
31433121

3144-
def create_binop(self, operand, left=ast.Constant(1), right=ast.Constant(1)):
3145-
return ast.BinOp(left=left, op=self.binop[operand], right=right)
3146-
3147-
def test_folding_binop(self):
3148-
code = "1 %s 1"
3149-
operators = self.binop.keys()
3150-
3151-
for op in operators:
3152-
result_code = code % op
3153-
non_optimized_target = self.wrap_expr(self.create_binop(op))
3154-
optimized_target = self.wrap_expr(ast.Constant(value=eval(result_code)))
3155-
3156-
with self.subTest(
3157-
result_code=result_code,
3158-
non_optimized_target=non_optimized_target,
3159-
optimized_target=optimized_target
3160-
):
3161-
self.assert_ast(result_code, non_optimized_target, optimized_target)
3162-
3163-
# Multiplication of constant tuples must be folded
3164-
code = "(1,) * 3"
3165-
non_optimized_target = self.wrap_expr(self.create_binop("*", ast.Tuple(elts=[ast.Constant(value=1)]), ast.Constant(value=3)))
3166-
optimized_target = self.wrap_expr(ast.Constant(eval(code)))
3167-
3168-
self.assert_ast(code, non_optimized_target, optimized_target)
3169-
3170-
def test_folding_unaryop(self):
3171-
code = "%s1"
3172-
operators = self.unaryop.keys()
3173-
3174-
def create_unaryop(operand):
3175-
return ast.UnaryOp(op=self.unaryop[operand], operand=ast.Constant(1))
3176-
3177-
for op in operators:
3178-
result_code = code % op
3179-
non_optimized_target = self.wrap_expr(create_unaryop(op))
3180-
optimized_target = self.wrap_expr(ast.Constant(eval(result_code)))
3181-
3182-
with self.subTest(
3183-
result_code=result_code,
3184-
non_optimized_target=non_optimized_target,
3185-
optimized_target=optimized_target
3186-
):
3187-
self.assert_ast(result_code, non_optimized_target, optimized_target)
3188-
3189-
def test_folding_not(self):
3190-
code = "not (1 %s (1,))"
3191-
operators = {
3192-
"in": ast.In(),
3193-
"is": ast.Is(),
3194-
}
3195-
opt_operators = {
3196-
"is": ast.IsNot(),
3197-
"in": ast.NotIn(),
3198-
}
3199-
3200-
def create_notop(operand):
3201-
return ast.UnaryOp(op=ast.Not(), operand=ast.Compare(
3202-
left=ast.Constant(value=1),
3203-
ops=[operators[operand]],
3204-
comparators=[ast.Tuple(elts=[ast.Constant(value=1)])]
3205-
))
3206-
3207-
for op in operators.keys():
3208-
result_code = code % op
3209-
non_optimized_target = self.wrap_expr(create_notop(op))
3210-
optimized_target = self.wrap_expr(
3211-
ast.Compare(left=ast.Constant(1), ops=[opt_operators[op]], comparators=[ast.Constant(value=(1,))])
3212-
)
3213-
3214-
with self.subTest(
3215-
result_code=result_code,
3216-
non_optimized_target=non_optimized_target,
3217-
optimized_target=optimized_target
3218-
):
3219-
self.assert_ast(result_code, non_optimized_target, optimized_target)
3220-
32213122
def test_folding_format(self):
32223123
code = "'%s' % (a,)"
32233124

@@ -3247,9 +3148,9 @@ def test_folding_tuple(self):
32473148
self.assert_ast(code, non_optimized_target, optimized_target)
32483149

32493150
def test_folding_type_param_in_function_def(self):
3250-
code = "def foo[%s = 1 + 1](): pass"
3151+
code = "def foo[%s = (1, 2)](): pass"
32513152

3252-
unoptimized_binop = self.create_binop("+")
3153+
unoptimized_tuple = ast.Tuple(elts=[ast.Constant(1), ast.Constant(2)])
32533154
unoptimized_type_params = [
32543155
("T", "T", ast.TypeVar),
32553156
("**P", "P", ast.ParamSpec),
@@ -3263,23 +3164,23 @@ def test_folding_type_param_in_function_def(self):
32633164
name='foo',
32643165
args=ast.arguments(),
32653166
body=[ast.Pass()],
3266-
type_params=[type_param(name=name, default_value=ast.Constant(2))]
3167+
type_params=[type_param(name=name, default_value=ast.Constant((1, 2)))]
32673168
)
32683169
)
32693170
non_optimized_target = self.wrap_statement(
32703171
ast.FunctionDef(
32713172
name='foo',
32723173
args=ast.arguments(),
32733174
body=[ast.Pass()],
3274-
type_params=[type_param(name=name, default_value=unoptimized_binop)]
3175+
type_params=[type_param(name=name, default_value=unoptimized_tuple)]
32753176
)
32763177
)
32773178
self.assert_ast(result_code, non_optimized_target, optimized_target)
32783179

32793180
def test_folding_type_param_in_class_def(self):
3280-
code = "class foo[%s = 1 + 1]: pass"
3181+
code = "class foo[%s = (1, 2)]: pass"
32813182

3282-
unoptimized_binop = self.create_binop("+")
3183+
unoptimized_tuple = ast.Tuple(elts=[ast.Constant(1), ast.Constant(2)])
32833184
unoptimized_type_params = [
32843185
("T", "T", ast.TypeVar),
32853186
("**P", "P", ast.ParamSpec),
@@ -3292,22 +3193,22 @@ def test_folding_type_param_in_class_def(self):
32923193
ast.ClassDef(
32933194
name='foo',
32943195
body=[ast.Pass()],
3295-
type_params=[type_param(name=name, default_value=ast.Constant(2))]
3196+
type_params=[type_param(name=name, default_value=ast.Constant((1, 2)))]
32963197
)
32973198
)
32983199
non_optimized_target = self.wrap_statement(
32993200
ast.ClassDef(
33003201
name='foo',
33013202
body=[ast.Pass()],
3302-
type_params=[type_param(name=name, default_value=unoptimized_binop)]
3203+
type_params=[type_param(name=name, default_value=unoptimized_tuple)]
33033204
)
33043205
)
33053206
self.assert_ast(result_code, non_optimized_target, optimized_target)
33063207

33073208
def test_folding_type_param_in_type_alias(self):
3308-
code = "type foo[%s = 1 + 1] = 1"
3209+
code = "type foo[%s = (1, 2)] = 1"
33093210

3310-
unoptimized_binop = self.create_binop("+")
3211+
unoptimized_tuple = ast.Tuple(elts=[ast.Constant(1), ast.Constant(2)])
33113212
unoptimized_type_params = [
33123213
("T", "T", ast.TypeVar),
33133214
("**P", "P", ast.ParamSpec),
@@ -3319,19 +3220,80 @@ def test_folding_type_param_in_type_alias(self):
33193220
optimized_target = self.wrap_statement(
33203221
ast.TypeAlias(
33213222
name=ast.Name(id='foo', ctx=ast.Store()),
3322-
type_params=[type_param(name=name, default_value=ast.Constant(2))],
3223+
type_params=[type_param(name=name, default_value=ast.Constant((1, 2)))],
33233224
value=ast.Constant(value=1),
33243225
)
33253226
)
33263227
non_optimized_target = self.wrap_statement(
33273228
ast.TypeAlias(
33283229
name=ast.Name(id='foo', ctx=ast.Store()),
3329-
type_params=[type_param(name=name, default_value=unoptimized_binop)],
3230+
type_params=[type_param(name=name, default_value=unoptimized_tuple)],
33303231
value=ast.Constant(value=1),
33313232
)
33323233
)
33333234
self.assert_ast(result_code, non_optimized_target, optimized_target)
33343235

3236+
def test_folding_match_case_allowed_expressions(self):
3237+
def get_match_case_values(node):
3238+
result = []
3239+
if isinstance(node, ast.Constant):
3240+
result.append(node.value)
3241+
elif isinstance(node, ast.MatchValue):
3242+
result.extend(get_match_case_values(node.value))
3243+
elif isinstance(node, ast.MatchMapping):
3244+
for key in node.keys:
3245+
result.extend(get_match_case_values(key))
3246+
elif isinstance(node, ast.MatchSequence):
3247+
for pat in node.patterns:
3248+
result.extend(get_match_case_values(pat))
3249+
else:
3250+
self.fail(f"Unexpected node {node}")
3251+
return result
3252+
3253+
tests = [
3254+
("-0", [0]),
3255+
("-0.1", [-0.1]),
3256+
("-0j", [complex(0, 0)]),
3257+
("-0.1j", [complex(0, -0.1)]),
3258+
("1 + 2j", [complex(1, 2)]),
3259+
("1 - 2j", [complex(1, -2)]),
3260+
("1.1 + 2.1j", [complex(1.1, 2.1)]),
3261+
("1.1 - 2.1j", [complex(1.1, -2.1)]),
3262+
("-0 + 1j", [complex(0, 1)]),
3263+
("-0 - 1j", [complex(0, -1)]),
3264+
("-0.1 + 1.1j", [complex(-0.1, 1.1)]),
3265+
("-0.1 - 1.1j", [complex(-0.1, -1.1)]),
3266+
("{-0: 0}", [0]),
3267+
("{-0.1: 0}", [-0.1]),
3268+
("{-0j: 0}", [complex(0, 0)]),
3269+
("{-0.1j: 0}", [complex(0, -0.1)]),
3270+
("{1 + 2j: 0}", [complex(1, 2)]),
3271+
("{1 - 2j: 0}", [complex(1, -2)]),
3272+
("{1.1 + 2.1j: 0}", [complex(1.1, 2.1)]),
3273+
("{1.1 - 2.1j: 0}", [complex(1.1, -2.1)]),
3274+
("{-0 + 1j: 0}", [complex(0, 1)]),
3275+
("{-0 - 1j: 0}", [complex(0, -1)]),
3276+
("{-0.1 + 1.1j: 0}", [complex(-0.1, 1.1)]),
3277+
("{-0.1 - 1.1j: 0}", [complex(-0.1, -1.1)]),
3278+
("{-0: 0, 0 + 1j: 0, 0.1 + 1j: 0}", [0, complex(0, 1), complex(0.1, 1)]),
3279+
("[-0, -0.1, -0j, -0.1j]", [0, -0.1, complex(0, 0), complex(0, -0.1)]),
3280+
("[[[[-0, -0.1, -0j, -0.1j]]]]", [0, -0.1, complex(0, 0), complex(0, -0.1)]),
3281+
("[[-0, -0.1], -0j, -0.1j]", [0, -0.1, complex(0, 0), complex(0, -0.1)]),
3282+
("[[-0, -0.1], [-0j, -0.1j]]", [0, -0.1, complex(0, 0), complex(0, -0.1)]),
3283+
("(-0, -0.1, -0j, -0.1j)", [0, -0.1, complex(0, 0), complex(0, -0.1)]),
3284+
("((((-0, -0.1, -0j, -0.1j))))", [0, -0.1, complex(0, 0), complex(0, -0.1)]),
3285+
("((-0, -0.1), -0j, -0.1j)", [0, -0.1, complex(0, 0), complex(0, -0.1)]),
3286+
("((-0, -0.1), (-0j, -0.1j))", [0, -0.1, complex(0, 0), complex(0, -0.1)]),
3287+
]
3288+
for match_expr, constants in tests:
3289+
with self.subTest(match_expr):
3290+
src = f"match 0:\n\t case {match_expr}: pass"
3291+
tree = ast.parse(src, optimize=1)
3292+
match_stmt = tree.body[0]
3293+
case = match_stmt.cases[0]
3294+
values = get_match_case_values(case.pattern)
3295+
self.assertListEqual(constants, values)
3296+
33353297

33363298
if __name__ == '__main__':
33373299
if len(sys.argv) > 1 and sys.argv[1] == '--snapshot-update':

Lib/test/test_ast/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
def to_tuple(t):
2-
if t is None or isinstance(t, (str, int, complex, float, bytes)) or t is Ellipsis:
2+
if t is None or isinstance(t, (str, int, complex, float, bytes, tuple)) or t is Ellipsis:
33
return t
44
elif isinstance(t, list):
55
return [to_tuple(e) for e in t]

Lib/test/test_builtin.py

+6-9
Original file line numberDiff line numberDiff line change
@@ -555,7 +555,7 @@ def test_compile_async_generator(self):
555555
self.assertEqual(type(glob['ticker']()), AsyncGeneratorType)
556556

557557
def test_compile_ast(self):
558-
args = ("a*(1+2)", "f.py", "exec")
558+
args = ("a*(1,2)", "f.py", "exec")
559559
raw = compile(*args, flags = ast.PyCF_ONLY_AST).body[0]
560560
opt1 = compile(*args, flags = ast.PyCF_OPTIMIZED_AST).body[0]
561561
opt2 = compile(ast.parse(args[0]), *args[1:], flags = ast.PyCF_OPTIMIZED_AST).body[0]
@@ -566,17 +566,14 @@ def test_compile_ast(self):
566566
self.assertIsInstance(tree.value.left, ast.Name)
567567
self.assertEqual(tree.value.left.id, 'a')
568568

569-
raw_right = raw.value.right # expect BinOp(1, '+', 2)
570-
self.assertIsInstance(raw_right, ast.BinOp)
571-
self.assertIsInstance(raw_right.left, ast.Constant)
572-
self.assertEqual(raw_right.left.value, 1)
573-
self.assertIsInstance(raw_right.right, ast.Constant)
574-
self.assertEqual(raw_right.right.value, 2)
569+
raw_right = raw.value.right # expect Tuple((1, 2))
570+
self.assertIsInstance(raw_right, ast.Tuple)
571+
self.assertListEqual([elt.value for elt in raw_right.elts], [1, 2])
575572

576573
for opt in [opt1, opt2]:
577-
opt_right = opt.value.right # expect Constant(3)
574+
opt_right = opt.value.right # expect Constant((1,2))
578575
self.assertIsInstance(opt_right, ast.Constant)
579-
self.assertEqual(opt_right.value, 3)
576+
self.assertEqual(opt_right.value, (1, 2))
580577

581578
def test_delattr(self):
582579
sys.spam = 1

0 commit comments

Comments
 (0)