@@ -154,18 +154,17 @@ def test_optimization_levels__debug__(self):
154
154
self .assertEqual (res .body [0 ].value .id , expected )
155
155
156
156
def test_optimization_levels_const_folding (self ):
157
- folded = ('Expr' , (1 , 0 , 1 , 5 ), ('Constant' , (1 , 0 , 1 , 5 ), 3 , None ))
158
- not_folded = ('Expr' , (1 , 0 , 1 , 5 ),
159
- ('BinOp' , (1 , 0 , 1 , 5 ),
160
- ('Constant' , (1 , 0 , 1 , 1 ), 1 , None ),
161
- ('Add' ,),
162
- ('Constant' , (1 , 4 , 1 , 5 ), 2 , None )))
157
+ folded = ('Expr' , (1 , 0 , 1 , 6 ), ('Constant' , (1 , 0 , 1 , 6 ), (1 , 2 ), None ))
158
+ not_folded = ('Expr' , (1 , 0 , 1 , 6 ),
159
+ ('Tuple' , (1 , 0 , 1 , 6 ),
160
+ [('Constant' , (1 , 1 , 1 , 2 ), 1 , None ),
161
+ ('Constant' , (1 , 4 , 1 , 5 ), 2 , None )], ('Load' ,)))
163
162
164
163
cases = [(- 1 , not_folded ), (0 , not_folded ), (1 , folded ), (2 , folded )]
165
164
for (optval , expected ) in cases :
166
165
with self .subTest (optval = optval ):
167
- tree1 = ast .parse ("1 + 2 " , optimize = optval )
168
- tree2 = ast .parse (ast .parse ("1 + 2 " ), optimize = optval )
166
+ tree1 = ast .parse ("(1, 2) " , optimize = optval )
167
+ tree2 = ast .parse (ast .parse ("(1, 2) " ), optimize = optval )
169
168
for tree in [tree1 , tree2 ]:
170
169
res = to_tuple (tree .body [0 ])
171
170
self .assertEqual (res , expected )
@@ -3089,27 +3088,6 @@ def test_cli_file_input(self):
3089
3088
3090
3089
3091
3090
class ASTOptimiziationTests (unittest .TestCase ):
3092
- binop = {
3093
- "+" : ast .Add (),
3094
- "-" : ast .Sub (),
3095
- "*" : ast .Mult (),
3096
- "/" : ast .Div (),
3097
- "%" : ast .Mod (),
3098
- "<<" : ast .LShift (),
3099
- ">>" : ast .RShift (),
3100
- "|" : ast .BitOr (),
3101
- "^" : ast .BitXor (),
3102
- "&" : ast .BitAnd (),
3103
- "//" : ast .FloorDiv (),
3104
- "**" : ast .Pow (),
3105
- }
3106
-
3107
- unaryop = {
3108
- "~" : ast .Invert (),
3109
- "+" : ast .UAdd (),
3110
- "-" : ast .USub (),
3111
- }
3112
-
3113
3091
def wrap_expr (self , expr ):
3114
3092
return ast .Module (body = [ast .Expr (value = expr )])
3115
3093
@@ -3141,83 +3119,6 @@ def assert_ast(self, code, non_optimized_target, optimized_target):
3141
3119
f"{ ast .dump (optimized_tree )} " ,
3142
3120
)
3143
3121
3144
- def create_binop (self , operand , left = ast .Constant (1 ), right = ast .Constant (1 )):
3145
- return ast .BinOp (left = left , op = self .binop [operand ], right = right )
3146
-
3147
- def test_folding_binop (self ):
3148
- code = "1 %s 1"
3149
- operators = self .binop .keys ()
3150
-
3151
- for op in operators :
3152
- result_code = code % op
3153
- non_optimized_target = self .wrap_expr (self .create_binop (op ))
3154
- optimized_target = self .wrap_expr (ast .Constant (value = eval (result_code )))
3155
-
3156
- with self .subTest (
3157
- result_code = result_code ,
3158
- non_optimized_target = non_optimized_target ,
3159
- optimized_target = optimized_target
3160
- ):
3161
- self .assert_ast (result_code , non_optimized_target , optimized_target )
3162
-
3163
- # Multiplication of constant tuples must be folded
3164
- code = "(1,) * 3"
3165
- non_optimized_target = self .wrap_expr (self .create_binop ("*" , ast .Tuple (elts = [ast .Constant (value = 1 )]), ast .Constant (value = 3 )))
3166
- optimized_target = self .wrap_expr (ast .Constant (eval (code )))
3167
-
3168
- self .assert_ast (code , non_optimized_target , optimized_target )
3169
-
3170
- def test_folding_unaryop (self ):
3171
- code = "%s1"
3172
- operators = self .unaryop .keys ()
3173
-
3174
- def create_unaryop (operand ):
3175
- return ast .UnaryOp (op = self .unaryop [operand ], operand = ast .Constant (1 ))
3176
-
3177
- for op in operators :
3178
- result_code = code % op
3179
- non_optimized_target = self .wrap_expr (create_unaryop (op ))
3180
- optimized_target = self .wrap_expr (ast .Constant (eval (result_code )))
3181
-
3182
- with self .subTest (
3183
- result_code = result_code ,
3184
- non_optimized_target = non_optimized_target ,
3185
- optimized_target = optimized_target
3186
- ):
3187
- self .assert_ast (result_code , non_optimized_target , optimized_target )
3188
-
3189
- def test_folding_not (self ):
3190
- code = "not (1 %s (1,))"
3191
- operators = {
3192
- "in" : ast .In (),
3193
- "is" : ast .Is (),
3194
- }
3195
- opt_operators = {
3196
- "is" : ast .IsNot (),
3197
- "in" : ast .NotIn (),
3198
- }
3199
-
3200
- def create_notop (operand ):
3201
- return ast .UnaryOp (op = ast .Not (), operand = ast .Compare (
3202
- left = ast .Constant (value = 1 ),
3203
- ops = [operators [operand ]],
3204
- comparators = [ast .Tuple (elts = [ast .Constant (value = 1 )])]
3205
- ))
3206
-
3207
- for op in operators .keys ():
3208
- result_code = code % op
3209
- non_optimized_target = self .wrap_expr (create_notop (op ))
3210
- optimized_target = self .wrap_expr (
3211
- ast .Compare (left = ast .Constant (1 ), ops = [opt_operators [op ]], comparators = [ast .Constant (value = (1 ,))])
3212
- )
3213
-
3214
- with self .subTest (
3215
- result_code = result_code ,
3216
- non_optimized_target = non_optimized_target ,
3217
- optimized_target = optimized_target
3218
- ):
3219
- self .assert_ast (result_code , non_optimized_target , optimized_target )
3220
-
3221
3122
def test_folding_format (self ):
3222
3123
code = "'%s' % (a,)"
3223
3124
@@ -3247,9 +3148,9 @@ def test_folding_tuple(self):
3247
3148
self .assert_ast (code , non_optimized_target , optimized_target )
3248
3149
3249
3150
def test_folding_type_param_in_function_def (self ):
3250
- code = "def foo[%s = 1 + 1 ](): pass"
3151
+ code = "def foo[%s = (1, 2) ](): pass"
3251
3152
3252
- unoptimized_binop = self . create_binop ( "+" )
3153
+ unoptimized_tuple = ast . Tuple ( elts = [ ast . Constant ( 1 ), ast . Constant ( 2 )] )
3253
3154
unoptimized_type_params = [
3254
3155
("T" , "T" , ast .TypeVar ),
3255
3156
("**P" , "P" , ast .ParamSpec ),
@@ -3263,23 +3164,23 @@ def test_folding_type_param_in_function_def(self):
3263
3164
name = 'foo' ,
3264
3165
args = ast .arguments (),
3265
3166
body = [ast .Pass ()],
3266
- type_params = [type_param (name = name , default_value = ast .Constant (2 ))]
3167
+ type_params = [type_param (name = name , default_value = ast .Constant (( 1 , 2 ) ))]
3267
3168
)
3268
3169
)
3269
3170
non_optimized_target = self .wrap_statement (
3270
3171
ast .FunctionDef (
3271
3172
name = 'foo' ,
3272
3173
args = ast .arguments (),
3273
3174
body = [ast .Pass ()],
3274
- type_params = [type_param (name = name , default_value = unoptimized_binop )]
3175
+ type_params = [type_param (name = name , default_value = unoptimized_tuple )]
3275
3176
)
3276
3177
)
3277
3178
self .assert_ast (result_code , non_optimized_target , optimized_target )
3278
3179
3279
3180
def test_folding_type_param_in_class_def (self ):
3280
- code = "class foo[%s = 1 + 1 ]: pass"
3181
+ code = "class foo[%s = (1, 2) ]: pass"
3281
3182
3282
- unoptimized_binop = self . create_binop ( "+" )
3183
+ unoptimized_tuple = ast . Tuple ( elts = [ ast . Constant ( 1 ), ast . Constant ( 2 )] )
3283
3184
unoptimized_type_params = [
3284
3185
("T" , "T" , ast .TypeVar ),
3285
3186
("**P" , "P" , ast .ParamSpec ),
@@ -3292,22 +3193,22 @@ def test_folding_type_param_in_class_def(self):
3292
3193
ast .ClassDef (
3293
3194
name = 'foo' ,
3294
3195
body = [ast .Pass ()],
3295
- type_params = [type_param (name = name , default_value = ast .Constant (2 ))]
3196
+ type_params = [type_param (name = name , default_value = ast .Constant (( 1 , 2 ) ))]
3296
3197
)
3297
3198
)
3298
3199
non_optimized_target = self .wrap_statement (
3299
3200
ast .ClassDef (
3300
3201
name = 'foo' ,
3301
3202
body = [ast .Pass ()],
3302
- type_params = [type_param (name = name , default_value = unoptimized_binop )]
3203
+ type_params = [type_param (name = name , default_value = unoptimized_tuple )]
3303
3204
)
3304
3205
)
3305
3206
self .assert_ast (result_code , non_optimized_target , optimized_target )
3306
3207
3307
3208
def test_folding_type_param_in_type_alias (self ):
3308
- code = "type foo[%s = 1 + 1 ] = 1"
3209
+ code = "type foo[%s = (1, 2) ] = 1"
3309
3210
3310
- unoptimized_binop = self . create_binop ( "+" )
3211
+ unoptimized_tuple = ast . Tuple ( elts = [ ast . Constant ( 1 ), ast . Constant ( 2 )] )
3311
3212
unoptimized_type_params = [
3312
3213
("T" , "T" , ast .TypeVar ),
3313
3214
("**P" , "P" , ast .ParamSpec ),
@@ -3319,19 +3220,80 @@ def test_folding_type_param_in_type_alias(self):
3319
3220
optimized_target = self .wrap_statement (
3320
3221
ast .TypeAlias (
3321
3222
name = ast .Name (id = 'foo' , ctx = ast .Store ()),
3322
- type_params = [type_param (name = name , default_value = ast .Constant (2 ))],
3223
+ type_params = [type_param (name = name , default_value = ast .Constant (( 1 , 2 ) ))],
3323
3224
value = ast .Constant (value = 1 ),
3324
3225
)
3325
3226
)
3326
3227
non_optimized_target = self .wrap_statement (
3327
3228
ast .TypeAlias (
3328
3229
name = ast .Name (id = 'foo' , ctx = ast .Store ()),
3329
- type_params = [type_param (name = name , default_value = unoptimized_binop )],
3230
+ type_params = [type_param (name = name , default_value = unoptimized_tuple )],
3330
3231
value = ast .Constant (value = 1 ),
3331
3232
)
3332
3233
)
3333
3234
self .assert_ast (result_code , non_optimized_target , optimized_target )
3334
3235
3236
+ def test_folding_match_case_allowed_expressions (self ):
3237
+ def get_match_case_values (node ):
3238
+ result = []
3239
+ if isinstance (node , ast .Constant ):
3240
+ result .append (node .value )
3241
+ elif isinstance (node , ast .MatchValue ):
3242
+ result .extend (get_match_case_values (node .value ))
3243
+ elif isinstance (node , ast .MatchMapping ):
3244
+ for key in node .keys :
3245
+ result .extend (get_match_case_values (key ))
3246
+ elif isinstance (node , ast .MatchSequence ):
3247
+ for pat in node .patterns :
3248
+ result .extend (get_match_case_values (pat ))
3249
+ else :
3250
+ self .fail (f"Unexpected node { node } " )
3251
+ return result
3252
+
3253
+ tests = [
3254
+ ("-0" , [0 ]),
3255
+ ("-0.1" , [- 0.1 ]),
3256
+ ("-0j" , [complex (0 , 0 )]),
3257
+ ("-0.1j" , [complex (0 , - 0.1 )]),
3258
+ ("1 + 2j" , [complex (1 , 2 )]),
3259
+ ("1 - 2j" , [complex (1 , - 2 )]),
3260
+ ("1.1 + 2.1j" , [complex (1.1 , 2.1 )]),
3261
+ ("1.1 - 2.1j" , [complex (1.1 , - 2.1 )]),
3262
+ ("-0 + 1j" , [complex (0 , 1 )]),
3263
+ ("-0 - 1j" , [complex (0 , - 1 )]),
3264
+ ("-0.1 + 1.1j" , [complex (- 0.1 , 1.1 )]),
3265
+ ("-0.1 - 1.1j" , [complex (- 0.1 , - 1.1 )]),
3266
+ ("{-0: 0}" , [0 ]),
3267
+ ("{-0.1: 0}" , [- 0.1 ]),
3268
+ ("{-0j: 0}" , [complex (0 , 0 )]),
3269
+ ("{-0.1j: 0}" , [complex (0 , - 0.1 )]),
3270
+ ("{1 + 2j: 0}" , [complex (1 , 2 )]),
3271
+ ("{1 - 2j: 0}" , [complex (1 , - 2 )]),
3272
+ ("{1.1 + 2.1j: 0}" , [complex (1.1 , 2.1 )]),
3273
+ ("{1.1 - 2.1j: 0}" , [complex (1.1 , - 2.1 )]),
3274
+ ("{-0 + 1j: 0}" , [complex (0 , 1 )]),
3275
+ ("{-0 - 1j: 0}" , [complex (0 , - 1 )]),
3276
+ ("{-0.1 + 1.1j: 0}" , [complex (- 0.1 , 1.1 )]),
3277
+ ("{-0.1 - 1.1j: 0}" , [complex (- 0.1 , - 1.1 )]),
3278
+ ("{-0: 0, 0 + 1j: 0, 0.1 + 1j: 0}" , [0 , complex (0 , 1 ), complex (0.1 , 1 )]),
3279
+ ("[-0, -0.1, -0j, -0.1j]" , [0 , - 0.1 , complex (0 , 0 ), complex (0 , - 0.1 )]),
3280
+ ("[[[[-0, -0.1, -0j, -0.1j]]]]" , [0 , - 0.1 , complex (0 , 0 ), complex (0 , - 0.1 )]),
3281
+ ("[[-0, -0.1], -0j, -0.1j]" , [0 , - 0.1 , complex (0 , 0 ), complex (0 , - 0.1 )]),
3282
+ ("[[-0, -0.1], [-0j, -0.1j]]" , [0 , - 0.1 , complex (0 , 0 ), complex (0 , - 0.1 )]),
3283
+ ("(-0, -0.1, -0j, -0.1j)" , [0 , - 0.1 , complex (0 , 0 ), complex (0 , - 0.1 )]),
3284
+ ("((((-0, -0.1, -0j, -0.1j))))" , [0 , - 0.1 , complex (0 , 0 ), complex (0 , - 0.1 )]),
3285
+ ("((-0, -0.1), -0j, -0.1j)" , [0 , - 0.1 , complex (0 , 0 ), complex (0 , - 0.1 )]),
3286
+ ("((-0, -0.1), (-0j, -0.1j))" , [0 , - 0.1 , complex (0 , 0 ), complex (0 , - 0.1 )]),
3287
+ ]
3288
+ for match_expr , constants in tests :
3289
+ with self .subTest (match_expr ):
3290
+ src = f"match 0:\n \t case { match_expr } : pass"
3291
+ tree = ast .parse (src , optimize = 1 )
3292
+ match_stmt = tree .body [0 ]
3293
+ case = match_stmt .cases [0 ]
3294
+ values = get_match_case_values (case .pattern )
3295
+ self .assertListEqual (constants , values )
3296
+
3335
3297
3336
3298
if __name__ == '__main__' :
3337
3299
if len (sys .argv ) > 1 and sys .argv [1 ] == '--snapshot-update' :
0 commit comments