Skip to content

Commit e1117c3

Browse files
authored
[mypyc] Precompute set literals for "in" ops against / iteration over set literals (#14409)
Towards mypyc/mypyc#726. (There's a Python compatibility bug that needs to be fixed before the issue can be closed.)

For example, the set literals here are now precomputed as frozensets at module initialization.

```
x in {1, 2.0, "3"}
x not in {1, 2.0, "3"}
for _ in {1, 2.0, "3"}: ...
```

Set literal items supported:

- Anything supported by `irbuild.constant_fold.constant_fold_expr()`
  - String and integer literals
  - Final references to int/str values
  - Certain int and str unary/binary ops that evaluate to a constant value
- `None`, `True`, and `False`
- Float, byte, and complex literals
- Tuple literals with only items listed above

**Results** (using gcc-9 on 64-bit Ubuntu)

Master @ 98cc165

running in_set ..........
interpreted: 0.495790s (avg of 5 iterations; stdev 6.8%)
compiled: 0.810029s (avg of 5 iterations; stdev 1.5%)
compiled is 0.612x faster

running set_literal_iteration .........................................................................................
interpreted: 0.020255s (avg of 45 iterations; stdev 2.5%)
compiled: 0.016336s (avg of 45 iterations; stdev 1.8%)
compiled is 1.240x faster

This PR

running in_set ..........
interpreted: 0.502020s (avg of 5 iterations; stdev 1.1%)
compiled: 0.390281s (avg of 5 iterations; stdev 6.2%)
compiled is 1.286x faster

running set_literal_iteration ..............................................................................................
interpreted: 0.019917s (avg of 47 iterations; stdev 2.2%)
compiled: 0.007134s (avg of 47 iterations; stdev 2.6%)
compiled is 2.792x faster

Benchmarks can be found here: mypyc/mypyc-benchmarks#32
1 parent 4ec6ea5 commit e1117c3

File tree

15 files changed

+403
-32
lines changed

15 files changed

+403
-32
lines changed

mypyc/analysis/ircheck.py

+14
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,15 @@ def check_tuple_items_valid_literals(self, op: LoadLiteral, t: tuple[object, ...
252252
if isinstance(x, tuple):
253253
self.check_tuple_items_valid_literals(op, x)
254254

255+
def check_frozenset_items_valid_literals(self, op: LoadLiteral, s: frozenset[object]) -> None:
    """Report an error for each item of a frozenset literal with an unsupported type.

    Items that are None or instances of str, bytes, bool, int, float or
    complex are accepted as-is. Tuple items are handed to
    check_tuple_items_valid_literals() for checking. Any other item
    (for example a nested frozenset) is reported through self.fail().

    Args:
        op: The LoadLiteral op being checked; passed along when reporting errors.
        s: The frozenset value carried by the op.
    """
    for x in s:
        if x is None or isinstance(x, (str, bytes, bool, int, float, complex)):
            continue
        if isinstance(x, tuple):
            self.check_tuple_items_valid_literals(op, x)
        else:
            # The message previously ended with an unbalanced ")" after the type.
            self.fail(op, f"Invalid type for item of frozenset literal: {type(x)}")
263+
255264
def visit_load_literal(self, op: LoadLiteral) -> None:
256265
expected_type = None
257266
if op.value is None:
@@ -271,6 +280,11 @@ def visit_load_literal(self, op: LoadLiteral) -> None:
271280
elif isinstance(op.value, tuple):
272281
expected_type = "builtins.tuple"
273282
self.check_tuple_items_valid_literals(op, op.value)
283+
elif isinstance(op.value, frozenset):
284+
# There's no frozenset_rprimitive type since it'd be pretty useless so we just pretend
285+
# it's a set (when it's really a frozenset).
286+
expected_type = "builtins.set"
287+
self.check_frozenset_items_valid_literals(op, op.value)
274288

275289
assert expected_type is not None, "Missed a case for LoadLiteral check"
276290

mypyc/codegen/emitmodule.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -669,6 +669,9 @@ def generate_literal_tables(self) -> None:
669669
# Descriptions of tuple literals
670670
init_tuple = c_array_initializer(literals.encoded_tuple_values())
671671
self.declare_global("const int []", "CPyLit_Tuple", initializer=init_tuple)
672+
# Descriptions of frozenset literals
673+
init_frozenset = c_array_initializer(literals.encoded_frozenset_values())
674+
self.declare_global("const int []", "CPyLit_FrozenSet", initializer=init_frozenset)
672675

673676
def generate_export_table(self, decl_emitter: Emitter, code_emitter: Emitter) -> None:
674677
"""Generate the declaration and definition of the group's export struct.
@@ -839,7 +842,7 @@ def generate_globals_init(self, emitter: Emitter) -> None:
839842
for symbol, fixup in self.simple_inits:
840843
emitter.emit_line(f"{symbol} = {fixup};")
841844

842-
values = "CPyLit_Str, CPyLit_Bytes, CPyLit_Int, CPyLit_Float, CPyLit_Complex, CPyLit_Tuple"
845+
values = "CPyLit_Str, CPyLit_Bytes, CPyLit_Int, CPyLit_Float, CPyLit_Complex, CPyLit_Tuple, CPyLit_FrozenSet"
843846
emitter.emit_lines(
844847
f"if (CPyStatics_Initialize(CPyStatics, {values}) < 0) {{", "return -1;", "}"
845848
)

mypyc/codegen/literals.py

+33-13
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from __future__ import annotations
22

3-
from typing import Any, Tuple, Union, cast
3+
from typing import Any, Dict, FrozenSet, List, Tuple, Union, cast
44
from typing_extensions import Final
55

6-
# Supported Python literal types. All tuple items must have supported
6+
# Supported Python literal types. All tuple / frozenset items must have supported
77
# literal types as well, but we can't represent the type precisely.
8-
LiteralValue = Union[str, bytes, int, bool, float, complex, Tuple[object, ...], None]
9-
8+
LiteralValue = Union[
9+
str, bytes, int, bool, float, complex, Tuple[object, ...], FrozenSet[object], None
10+
]
1011

1112
# Some literals are singletons and handled specially (None, False and True)
1213
NUM_SINGLETONS: Final = 3
@@ -23,6 +24,7 @@ def __init__(self) -> None:
2324
self.float_literals: dict[float, int] = {}
2425
self.complex_literals: dict[complex, int] = {}
2526
self.tuple_literals: dict[tuple[object, ...], int] = {}
27+
self.frozenset_literals: dict[frozenset[object], int] = {}
2628

2729
def record_literal(self, value: LiteralValue) -> None:
2830
"""Ensure that the literal value is available in generated code."""
@@ -55,6 +57,12 @@ def record_literal(self, value: LiteralValue) -> None:
5557
for item in value:
5658
self.record_literal(cast(Any, item))
5759
tuple_literals[value] = len(tuple_literals)
60+
elif isinstance(value, frozenset):
61+
frozenset_literals = self.frozenset_literals
62+
if value not in frozenset_literals:
63+
for item in value:
64+
self.record_literal(cast(Any, item))
65+
frozenset_literals[value] = len(frozenset_literals)
5866
else:
5967
assert False, "invalid literal: %r" % value
6068

@@ -86,6 +94,9 @@ def literal_index(self, value: LiteralValue) -> int:
8694
n += len(self.complex_literals)
8795
if isinstance(value, tuple):
8896
return n + self.tuple_literals[value]
97+
n += len(self.tuple_literals)
98+
if isinstance(value, frozenset):
99+
return n + self.frozenset_literals[value]
89100
assert False, "invalid literal: %r" % value
90101

91102
def num_literals(self) -> int:
@@ -98,6 +109,7 @@ def num_literals(self) -> int:
98109
+ len(self.float_literals)
99110
+ len(self.complex_literals)
100111
+ len(self.tuple_literals)
112+
+ len(self.frozenset_literals)
101113
)
102114

103115
# The following methods return the C encodings of literal values
@@ -119,24 +131,32 @@ def encoded_complex_values(self) -> list[str]:
119131
return _encode_complex_values(self.complex_literals)
120132

121133
def encoded_tuple_values(self) -> list[str]:
122-
"""Encode tuple values into a C array.
134+
return self._encode_collection_values(self.tuple_literals)
135+
136+
def encoded_frozenset_values(self) -> List[str]:
    """Return the C array encoding of the recorded frozenset literals."""
    recorded_frozensets = self.frozenset_literals
    return self._encode_collection_values(recorded_frozensets)
138+
139+
def _encode_collection_values(
140+
self, values: dict[tuple[object, ...], int] | dict[frozenset[object], int]
141+
) -> list[str]:
142+
"""Encode tuple/frozenset values into a C array.
123143
124144
The format of the result is like this:
125145
126-
<number of tuples>
127-
<length of the first tuple>
146+
<number of collections>
147+
<length of the first collection>
128148
<literal index of first item>
129149
...
130150
<literal index of last item>
131-
<length of the second tuple>
151+
<length of the second collection>
132152
...
133153
"""
134-
values = self.tuple_literals
135-
value_by_index = {index: value for value, index in values.items()}
154+
# FIXME: https://github.com/mypyc/mypyc/issues/965
155+
value_by_index = {index: value for value, index in cast(Dict[Any, int], values).items()}
136156
result = []
137-
num = len(values)
138-
result.append(str(num))
139-
for i in range(num):
157+
count = len(values)
158+
result.append(str(count))
159+
for i in range(count):
140160
value = value_by_index[i]
141161
result.append(str(len(value)))
142162
for item in value:

mypyc/ir/ops.py

+3-6
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
)
4040

4141
if TYPE_CHECKING:
42+
from mypyc.codegen.literals import LiteralValue
4243
from mypyc.ir.class_ir import ClassIR
4344
from mypyc.ir.func_ir import FuncDecl, FuncIR
4445

@@ -588,7 +589,7 @@ class LoadLiteral(RegisterOp):
588589
This is used to load a static PyObject * value corresponding to
589590
a literal of one of the supported types.
590591
591-
Tuple literals must contain only valid literal values as items.
592+
Tuple / frozenset literals must contain only valid literal values as items.
592593
593594
NOTE: You can use this to load boxed (Python) int objects. Use
594595
Integer to load unboxed, tagged integers or fixed-width,
@@ -603,11 +604,7 @@ class LoadLiteral(RegisterOp):
603604
error_kind = ERR_NEVER
604605
is_borrowed = True
605606

606-
def __init__(
607-
self,
608-
value: None | str | bytes | bool | int | float | complex | tuple[object, ...],
609-
rtype: RType,
610-
) -> None:
607+
def __init__(self, value: LiteralValue, rtype: RType) -> None:
611608
self.value = value
612609
self.type = rtype
613610

mypyc/ir/pprint.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,18 @@ def visit_load_literal(self, op: LoadLiteral) -> str:
106106
# it explicit that this is a Python object.
107107
if isinstance(op.value, int):
108108
prefix = "object "
109-
return self.format("%r = %s%s", op, prefix, repr(op.value))
109+
110+
rvalue = repr(op.value)
111+
if isinstance(op.value, frozenset):
112+
# We need to generate a string representation that won't vary
113+
# run-to-run because sets are unordered, otherwise we may get
114+
# spurious irbuild test failures.
115+
#
116+
# Sorting by the item's string representation is a bit of a
117+
# hack, but it's stable and won't cause TypeErrors.
118+
formatted_items = [repr(i) for i in sorted(op.value, key=str)]
119+
rvalue = "frozenset({" + ", ".join(formatted_items) + "})"
120+
return self.format("%r = %s%s", op, prefix, rvalue)
110121

111122
def visit_get_attr(self, op: GetAttr) -> str:
112123
return self.format("%r = %s%r.%s", op, self.borrow_prefix(op), op.obj, op.attr)

mypyc/irbuild/builder.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@
118118
AssignmentTargetRegister,
119119
AssignmentTargetTuple,
120120
)
121-
from mypyc.irbuild.util import is_constant
121+
from mypyc.irbuild.util import bytes_from_str, is_constant
122122
from mypyc.options import CompilerOptions
123123
from mypyc.primitives.dict_ops import dict_get_item_op, dict_set_item_op
124124
from mypyc.primitives.generic_ops import iter_op, next_op, py_setattr_op
@@ -296,8 +296,7 @@ def load_bytes_from_str_literal(self, value: str) -> Value:
296296
are stored in BytesExpr.value, whose type is 'str' not 'bytes'.
297297
Thus we perform a special conversion here.
298298
"""
299-
bytes_value = bytes(value, "utf8").decode("unicode-escape").encode("raw-unicode-escape")
300-
return self.builder.load_bytes(bytes_value)
299+
return self.builder.load_bytes(bytes_from_str(value))
301300

302301
def load_int(self, value: int) -> Value:
303302
return self.builder.load_int(value)
@@ -886,7 +885,7 @@ def get_dict_base_type(self, expr: Expression) -> Instance:
886885
This is useful for dict subclasses like SymbolTable.
887886
"""
888887
target_type = get_proper_type(self.types[expr])
889-
assert isinstance(target_type, Instance)
888+
assert isinstance(target_type, Instance), target_type
890889
dict_base = next(base for base in target_type.type.mro if base.fullname == "builtins.dict")
891890
return map_instance_to_supertype(target_type, dict_base)
892891

mypyc/irbuild/expression.py

+71-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from __future__ import annotations
88

9-
from typing import Callable, cast
9+
from typing import Callable, Sequence, cast
1010

1111
from mypy.nodes import (
1212
ARG_POS,
@@ -55,6 +55,7 @@
5555
ComparisonOp,
5656
Integer,
5757
LoadAddress,
58+
LoadLiteral,
5859
RaiseStandardError,
5960
Register,
6061
TupleGet,
@@ -63,12 +64,14 @@
6364
)
6465
from mypyc.ir.rtypes import (
6566
RTuple,
67+
bool_rprimitive,
6668
int_rprimitive,
6769
is_fixed_width_rtype,
6870
is_int_rprimitive,
6971
is_list_rprimitive,
7072
is_none_rprimitive,
7173
object_rprimitive,
74+
set_rprimitive,
7275
)
7376
from mypyc.irbuild.ast_helpers import is_borrow_friendly_expr, process_conditional
7477
from mypyc.irbuild.builder import IRBuilder, int_borrow_friendly_op
@@ -86,14 +89,15 @@
8689
tokenizer_printf_style,
8790
)
8891
from mypyc.irbuild.specialize import apply_function_specialization, apply_method_specialization
92+
from mypyc.irbuild.util import bytes_from_str
8993
from mypyc.primitives.bytes_ops import bytes_slice_op
9094
from mypyc.primitives.dict_ops import dict_get_item_op, dict_new_op, dict_set_item_op
9195
from mypyc.primitives.generic_ops import iter_op
9296
from mypyc.primitives.int_ops import int_comparison_op_mapping
9397
from mypyc.primitives.list_ops import list_append_op, list_extend_op, list_slice_op
9498
from mypyc.primitives.misc_ops import ellipsis_op, get_module_dict_op, new_slice_op, type_op
9599
from mypyc.primitives.registry import CFunctionDescription, builtin_names
96-
from mypyc.primitives.set_ops import set_add_op, set_update_op
100+
from mypyc.primitives.set_ops import set_add_op, set_in_op, set_update_op
97101
from mypyc.primitives.str_ops import str_slice_op
98102
from mypyc.primitives.tuple_ops import list_tuple_op, tuple_slice_op
99103

@@ -613,6 +617,54 @@ def transform_conditional_expr(builder: IRBuilder, expr: ConditionalExpr) -> Val
613617
return target
614618

615619

620+
def set_literal_values(builder: IRBuilder, items: Sequence[Expression]) -> list[object] | None:
    """Convert literal item expressions into the Python values they denote.

    Each item is first passed to constant_fold_expr(). Items it cannot fold
    are accepted if they are references to None/True/False, bytes, float or
    complex literals, or tuple literals whose own items are all convertible.

    Return the list of values, or None if any item could not be converted.
    """
    singleton_by_fullname = {"builtins.None": None, "builtins.True": True, "builtins.False": False}
    converted: list[object] = []
    for expr in items:
        folded = constant_fold_expr(builder, expr)
        if folded is not None:
            converted.append(folded)
        elif isinstance(expr, RefExpr):
            if expr.fullname in singleton_by_fullname:
                converted.append(singleton_by_fullname[expr.fullname])
        elif isinstance(expr, BytesExpr):
            # constant_fold_expr() doesn't handle these (yet?)
            converted.append(bytes_from_str(expr.value))
        elif isinstance(expr, (FloatExpr, ComplexExpr)):
            # constant_fold_expr() doesn't handle these (yet?)
            converted.append(expr.value)
        elif isinstance(expr, TupleExpr):
            nested = set_literal_values(builder, expr.items)
            if nested is not None:
                converted.append(tuple(nested))

    # An item that could not be converted appended nothing above, so a
    # length mismatch means we have to bail.
    return converted if len(converted) == len(items) else None
648+
649+
650+
def precompute_set_literal(builder: IRBuilder, s: SetExpr) -> Value | None:
    """Try to load a set literal as a precomputed frozenset literal.

    The items of the set expression are converted with set_literal_values().
    If that succeeds, a LoadLiteral op holding a frozenset of the values
    (typed as set_rprimitive) is added to the builder and returned.

    Return None if the items can't all be converted.

    Supported items:
     - Anything supported by irbuild.constant_fold.constant_fold_expr()
     - None, True, and False
     - Float, byte, and complex literals
     - Tuple literals with only items listed above
    """
    item_values = set_literal_values(builder, s.items)
    if item_values is None:
        return None

    literal_op = LoadLiteral(frozenset(item_values), set_rprimitive)
    return builder.add(literal_op)
666+
667+
616668
def transform_comparison_expr(builder: IRBuilder, e: ComparisonExpr) -> Value:
617669
# x in (...)/[...]
618670
# x not in (...)/[...]
@@ -666,6 +718,23 @@ def transform_comparison_expr(builder: IRBuilder, e: ComparisonExpr) -> Value:
666718
else:
667719
return builder.true()
668720

721+
# x in {...}
722+
# x not in {...}
723+
if (
724+
first_op in ("in", "not in")
725+
and len(e.operators) == 1
726+
and isinstance(e.operands[1], SetExpr)
727+
):
728+
set_literal = precompute_set_literal(builder, e.operands[1])
729+
if set_literal is not None:
730+
lhs = e.operands[0]
731+
result = builder.builder.call_c(
732+
set_in_op, [builder.accept(lhs), set_literal], e.line, bool_rprimitive
733+
)
734+
if first_op == "not in":
735+
return builder.unary_op(result, "not", e.line)
736+
return result
737+
669738
if len(e.operators) == 1:
670739
# Special some common simple cases
671740
if first_op in ("is", "is not"):

mypyc/irbuild/for_helpers.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
Lvalue,
1818
MemberExpr,
1919
RefExpr,
20+
SetExpr,
2021
TupleExpr,
2122
TypeAlias,
2223
)
@@ -469,12 +470,22 @@ def make_for_loop_generator(
469470
for_dict_gen.init(expr_reg, target_type)
470471
return for_dict_gen
471472

473+
iterable_expr_reg: Value | None = None
474+
if isinstance(expr, SetExpr):
475+
# Special case "for x in <set literal>".
476+
from mypyc.irbuild.expression import precompute_set_literal
477+
478+
set_literal = precompute_set_literal(builder, expr)
479+
if set_literal is not None:
480+
iterable_expr_reg = set_literal
481+
472482
# Default to a generic for loop.
473-
expr_reg = builder.accept(expr)
483+
if iterable_expr_reg is None:
484+
iterable_expr_reg = builder.accept(expr)
474485
for_obj = ForIterable(builder, index, body_block, loop_exit, line, nested)
475486
item_type = builder._analyze_iterable_item_type(expr)
476487
item_rtype = builder.type_to_rtype(item_type)
477-
for_obj.init(expr_reg, item_rtype)
488+
for_obj.init(iterable_expr_reg, item_rtype)
478489
return for_obj
479490

480491

mypyc/irbuild/util.py

+10
Original file line numberDiff line numberDiff line change
@@ -177,3 +177,13 @@ def is_constant(e: Expression) -> bool:
177177
)
178178
)
179179
)
180+
181+
182+
def bytes_from_str(value: str) -> bytes:
    """Turn the textual contents of a bytes literal into a bytes object.

    BytesExpr.value holds the characters written inside b'' as a 'str'
    (not 'bytes'), so escape sequences in it still need to be interpreted.
    """
    # Interpret escape sequences such as \n or \xff ...
    unescaped = value.encode("utf8").decode("unicode-escape")
    # ... then turn the resulting characters back into bytes.
    return unescaped.encode("raw-unicode-escape")

0 commit comments

Comments
 (0)