diff --git a/mypyc/ir/ops.py b/mypyc/ir/ops.py
index c09872ca3826..36105e3538d8 100644
--- a/mypyc/ir/ops.py
+++ b/mypyc/ir/ops.py
@@ -1210,6 +1210,7 @@ class RaiseStandardError(RegisterOp):
     RUNTIME_ERROR: Final = "RuntimeError"
     NAME_ERROR: Final = "NameError"
     ZERO_DIVISION_ERROR: Final = "ZeroDivisionError"
+    INDEX_ERROR: Final = "IndexError"
 
     def __init__(self, class_name: str, value: str | Value | None, line: int) -> None:
         super().__init__(line)
diff --git a/mypyc/ir/rtypes.py b/mypyc/ir/rtypes.py
index 2ef92efa2f0b..1335f818bebc 100644
--- a/mypyc/ir/rtypes.py
+++ b/mypyc/ir/rtypes.py
@@ -520,6 +520,8 @@ def __hash__(self) -> int:
     ]
 }
 
+bytes_writer_rprimitive: Final = KNOWN_NATIVE_TYPES["librt.strings.BytesWriter"]
+
 
 def is_native_rprimitive(rtype: RType) -> bool:
     return isinstance(rtype, RPrimitive) and rtype.name in KNOWN_NATIVE_TYPES
diff --git a/mypyc/irbuild/expression.py b/mypyc/irbuild/expression.py
index 04a55fb257f0..d0e91125e80f 100644
--- a/mypyc/irbuild/expression.py
+++ b/mypyc/irbuild/expression.py
@@ -97,6 +97,7 @@
     tokenizer_printf_style,
 )
 from mypyc.irbuild.specialize import (
+    apply_dunder_specialization,
     apply_function_specialization,
     apply_method_specialization,
     translate_object_new,
@@ -587,6 +588,12 @@ def transform_index_expr(builder: IRBuilder, expr: IndexExpr) -> Value:
     is_list = is_list_rprimitive(base_type)
     can_borrow_base = is_list and is_borrow_friendly_expr(builder, index)
 
+    # Check for dunder specialization for non-slice indexing
+    if not isinstance(index, SliceExpr):
+        specialized = apply_dunder_specialization(builder, expr.base, [index], "__getitem__", expr)
+        if specialized is not None:
+            return specialized
+
     base = builder.accept(expr.base, can_borrow=can_borrow_base)
 
     if isinstance(base.type, RTuple):
diff --git a/mypyc/irbuild/specialize.py b/mypyc/irbuild/specialize.py
index 6e08c4e8b2f4..c0ed5f9f69da 100644
--- a/mypyc/irbuild/specialize.py
+++ b/mypyc/irbuild/specialize.py
@@ -55,6 +55,7 @@
     RType,
     bool_rprimitive,
     bytes_rprimitive,
+    bytes_writer_rprimitive,
     c_int_rprimitive,
     dict_rprimitive,
     int16_rprimitive,
@@ -104,6 +105,12 @@
 from mypyc.primitives.float_ops import isinstance_float
 from mypyc.primitives.generic_ops import generic_setattr, setup_object
 from mypyc.primitives.int_ops import isinstance_int
+from mypyc.primitives.librt_strings_ops import (
+    bytes_writer_adjust_index_op,
+    bytes_writer_get_item_unsafe_op,
+    bytes_writer_range_check_op,
+    bytes_writer_set_item_unsafe_op,
+)
 from mypyc.primitives.list_ops import isinstance_list, new_list_set_item_op
 from mypyc.primitives.misc_ops import isinstance_bool
 from mypyc.primitives.set_ops import isinstance_frozenset, isinstance_set
@@ -127,12 +134,25 @@
 # compiled, and the RefExpr that is the left hand side of the call.
 Specializer = Callable[["IRBuilder", CallExpr, RefExpr], Value | None]
 
+# Dunder specializers are for special method calls like __getitem__, __setitem__, etc.
+# that don't naturally map to CallExpr nodes (e.g., from IndexExpr).
+#
+# They take four arguments: the IRBuilder, the base expression (target object),
+# the list of argument expressions (positional arguments to the dunder), and the
+# context expression (e.g., IndexExpr) for error reporting.
+DunderSpecializer = Callable[["IRBuilder", Expression, list[Expression], Expression], Value | None]
+
 # Dictionary containing all configured specializers.
 #
 # Specializers can operate on methods as well, and are keyed on the
 # name and RType in that case.
 specializers: dict[tuple[str, RType | None], list[Specializer]] = {}
 
+# Dictionary containing all configured dunder specializers.
+#
+# Dunder specializers are keyed on the dunder name and RType (always a method call).
+dunder_specializers: dict[tuple[str, RType], list[DunderSpecializer]] = {}
+
 
 def _apply_specialization(
     builder: IRBuilder, expr: CallExpr, callee: RefExpr, name: str | None, typ: RType | None = None
@@ -182,6 +202,53 @@ def wrapper(f: Specializer) -> Specializer:
     return wrapper
 
 
+def specialize_dunder(name: str, typ: RType) -> Callable[[DunderSpecializer], DunderSpecializer]:
+    """Decorator to register a function as being a dunder specializer.
+
+    Dunder specializers handle special method calls like __getitem__ that
+    don't naturally map to CallExpr nodes.
+
+    There may exist multiple specializers for one dunder. When translating
+    dunder calls, the earlier appended specializer has higher priority.
+    """
+
+    def wrapper(f: DunderSpecializer) -> DunderSpecializer:
+        dunder_specializers.setdefault((name, typ), []).append(f)
+        return f
+
+    return wrapper
+
+
+def apply_dunder_specialization(
+    builder: IRBuilder,
+    base_expr: Expression,
+    args: list[Expression],
+    name: str,
+    ctx_expr: Expression,
+) -> Value | None:
+    """Invoke the DunderSpecializer callback if one has been registered.
+
+    Args:
+        builder: The IR builder
+        base_expr: The base expression (target object)
+        args: List of argument expressions (positional arguments to the dunder)
+        name: The dunder method name (e.g., "__getitem__")
+        ctx_expr: The context expression for error reporting (e.g., IndexExpr)
+
+    Returns:
+        The specialized value, or None if no specialization was found.
+    """
+    base_type = builder.node_type(base_expr)
+
+    # Check if there's a specializer for this dunder method and type
+    if (name, base_type) in dunder_specializers:
+        for specializer in dunder_specializers[name, base_type]:
+            val = specializer(builder, base_expr, args, ctx_expr)
+            if val is not None:
+                return val
+    return None
+
+
 @specialize_function("builtins.globals")
 def translate_globals(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None:
     if len(expr.args) == 0:
@@ -1137,3 +1204,98 @@ def translate_object_setattr(builder: IRBuilder, expr: CallExpr, callee: RefExpr
 
     name_reg = builder.accept(attr_name)
     return builder.call_c(generic_setattr, [self_reg, name_reg, value], expr.line)
+
+
+@specialize_dunder("__getitem__", bytes_writer_rprimitive)
+def translate_bytes_writer_get_item(
+    builder: IRBuilder, base_expr: Expression, args: list[Expression], ctx_expr: Expression
+) -> Value | None:
+    """Optimized BytesWriter.__getitem__ implementation with bounds checking."""
+    # Check that we have exactly one argument
+    if len(args) != 1:
+        return None
+
+    # Get the BytesWriter object
+    obj = builder.accept(base_expr)
+
+    # Get the index argument
+    index = builder.accept(args[0])
+
+    # Adjust the index (handle negative indices)
+    adjusted_index = builder.primitive_op(
+        bytes_writer_adjust_index_op, [obj, index], ctx_expr.line
+    )
+
+    # Check if the adjusted index is in valid range
+    range_check = builder.primitive_op(
+        bytes_writer_range_check_op, [obj, adjusted_index], ctx_expr.line
+    )
+
+    # Create blocks for branching
+    valid_block = BasicBlock()
+    invalid_block = BasicBlock()
+
+    builder.add_bool_branch(range_check, valid_block, invalid_block)
+
+    # Handle invalid index - raise IndexError
+    builder.activate_block(invalid_block)
+    builder.add(
+        RaiseStandardError(RaiseStandardError.INDEX_ERROR, "index out of range", ctx_expr.line)
+    )
+    builder.add(Unreachable())
+
+    # Handle valid index - get the item
+    builder.activate_block(valid_block)
+    result = builder.primitive_op(
+        bytes_writer_get_item_unsafe_op, [obj, adjusted_index], ctx_expr.line
+    )
+
+    return result
+
+
+@specialize_dunder("__setitem__", bytes_writer_rprimitive)
+def translate_bytes_writer_set_item(
+    builder: IRBuilder, base_expr: Expression, args: list[Expression], ctx_expr: Expression
+) -> Value | None:
+    """Optimized BytesWriter.__setitem__ implementation with bounds checking."""
+    # Check that we have exactly two arguments (index and value)
+    if len(args) != 2:
+        return None
+
+    # Get the BytesWriter object
+    obj = builder.accept(base_expr)
+
+    # Get the index and value arguments
+    index = builder.accept(args[0])
+    value = builder.accept(args[1])
+
+    # Adjust the index (handle negative indices)
+    adjusted_index = builder.primitive_op(
+        bytes_writer_adjust_index_op, [obj, index], ctx_expr.line
+    )
+
+    # Check if the adjusted index is in valid range
+    range_check = builder.primitive_op(
+        bytes_writer_range_check_op, [obj, adjusted_index], ctx_expr.line
+    )
+
+    # Create blocks for branching
+    valid_block = BasicBlock()
+    invalid_block = BasicBlock()
+
+    builder.add_bool_branch(range_check, valid_block, invalid_block)
+
+    # Handle invalid index - raise IndexError
+    builder.activate_block(invalid_block)
+    builder.add(
+        RaiseStandardError(RaiseStandardError.INDEX_ERROR, "index out of range", ctx_expr.line)
+    )
+    builder.add(Unreachable())
+
+    # Handle valid index - set the item
+    builder.activate_block(valid_block)
+    builder.primitive_op(
+        bytes_writer_set_item_unsafe_op, [obj, adjusted_index, value], ctx_expr.line
+    )
+
+    return builder.none()
diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py
index 6f0db9432f08..d306124d8c0f 100644
--- a/mypyc/irbuild/statement.py
+++ b/mypyc/irbuild/statement.py
@@ -29,6 +29,7 @@
     Import,
     ImportAll,
     ImportFrom,
+    IndexExpr,
     ListExpr,
     Lvalue,
     MatchStmt,
@@ -92,6 +93,7 @@
     TryFinallyNonlocalControl,
 )
 from mypyc.irbuild.prepare import GENERATOR_HELPER_NAME
+from mypyc.irbuild.specialize import apply_dunder_specialization
 from mypyc.irbuild.targets import (
     AssignmentTarget,
     AssignmentTargetAttr,
@@ -260,6 +262,15 @@ def transform_assignment_stmt(builder: IRBuilder, stmt: AssignmentStmt) -> None:
             return
 
     for lvalue in lvalues:
+        # Check for __setitem__ dunder specialization before converting to assignment target
+        if isinstance(lvalue, IndexExpr):
+            specialized = apply_dunder_specialization(
+                builder, lvalue.base, [lvalue.index, stmt.rvalue], "__setitem__", lvalue
+            )
+            if specialized is not None:
+                builder.flush_keep_alives()
+                continue
+
         target = builder.get_assignment_target(lvalue)
         builder.assign(target, rvalue_reg, line)
         builder.flush_keep_alives()
diff --git a/mypyc/lib-rt/byteswriter_extra_ops.h b/mypyc/lib-rt/byteswriter_extra_ops.h
index 885dfe082624..59410a077fcd 100644
--- a/mypyc/lib-rt/byteswriter_extra_ops.h
+++ b/mypyc/lib-rt/byteswriter_extra_ops.h
@@ -36,6 +36,26 @@ CPyBytesWriter_Append(PyObject *obj, uint8_t value) {
 
 char CPyBytesWriter_Write(PyObject *obj, PyObject *value);
 
+// If index is negative, convert to non-negative index (no range checking)
+static inline int64_t CPyBytesWriter_AdjustIndex(PyObject *obj, int64_t index) {
+    if (index < 0) {
+        return index + ((BytesWriterObject *)obj)->len;
+    }
+    return index;
+}
+
+static inline bool CPyBytesWriter_RangeCheck(PyObject *obj, int64_t index) {
+    return index >= 0 && index < ((BytesWriterObject *)obj)->len;
+}
+
+static inline uint8_t CPyBytesWriter_GetItem(PyObject *obj, int64_t index) {
+    return (((BytesWriterObject *)obj)->buf)[index];
+}
+
+static inline void CPyBytesWriter_SetItem(PyObject *obj, int64_t index, uint8_t x) {
+    (((BytesWriterObject *)obj)->buf)[index] = x;
+}
+
 #endif // MYPYC_EXPERIMENTAL
 
 #endif
diff --git a/mypyc/primitives/librt_strings_ops.py b/mypyc/primitives/librt_strings_ops.py
index 1120254e24ae..ac1aa7da1fbe 100644
--- a/mypyc/primitives/librt_strings_ops.py
+++ b/mypyc/primitives/librt_strings_ops.py
@@ -1,18 +1,16 @@
-from typing import Final
-
 from mypyc.ir.deps import BYTES_WRITER_EXTRA_OPS, LIBRT_STRINGS
 from mypyc.ir.ops import ERR_MAGIC, ERR_NEVER
 from mypyc.ir.rtypes import (
-    KNOWN_NATIVE_TYPES,
+    bool_rprimitive,
     bytes_rprimitive,
+    bytes_writer_rprimitive,
     int64_rprimitive,
     none_rprimitive,
     short_int_rprimitive,
     uint8_rprimitive,
+    void_rtype,
 )
-from mypyc.primitives.registry import function_op, method_op
-
-bytes_writer_rprimitive: Final = KNOWN_NATIVE_TYPES["librt.strings.BytesWriter"]
+from mypyc.primitives.registry import custom_primitive_op, function_op, method_op
 
 function_op(
     name="librt.strings.BytesWriter",
@@ -73,3 +71,47 @@
     experimental=True,
     dependencies=[LIBRT_STRINGS, BYTES_WRITER_EXTRA_OPS],
 )
+
+# BytesWriter index adjustment - convert negative index to positive
+bytes_writer_adjust_index_op = custom_primitive_op(
+    name="bytes_writer_adjust_index",
+    arg_types=[bytes_writer_rprimitive, int64_rprimitive],
+    return_type=int64_rprimitive,
+    c_function_name="CPyBytesWriter_AdjustIndex",
+    error_kind=ERR_NEVER,
+    experimental=True,
+    dependencies=[LIBRT_STRINGS, BYTES_WRITER_EXTRA_OPS],
+)
+
+# BytesWriter range check - check if index is in valid range
+bytes_writer_range_check_op = custom_primitive_op(
+    name="bytes_writer_range_check",
+    arg_types=[bytes_writer_rprimitive, int64_rprimitive],
+    return_type=bool_rprimitive,
+    c_function_name="CPyBytesWriter_RangeCheck",
+    error_kind=ERR_NEVER,
+    experimental=True,
+    dependencies=[LIBRT_STRINGS, BYTES_WRITER_EXTRA_OPS],
+)
+
+# BytesWriter.__getitem__() - get byte at index (no bounds checking)
+bytes_writer_get_item_unsafe_op = custom_primitive_op(
+    name="bytes_writer_get_item",
+    arg_types=[bytes_writer_rprimitive, int64_rprimitive],
+    return_type=uint8_rprimitive,
+    c_function_name="CPyBytesWriter_GetItem",
+    error_kind=ERR_NEVER,
+    experimental=True,
+    dependencies=[LIBRT_STRINGS, BYTES_WRITER_EXTRA_OPS],
+)
+
+# BytesWriter.__setitem__() - set byte at index (no bounds checking)
+bytes_writer_set_item_unsafe_op = custom_primitive_op(
+    name="bytes_writer_set_item",
+    arg_types=[bytes_writer_rprimitive, int64_rprimitive, uint8_rprimitive],
+    return_type=void_rtype,
+    c_function_name="CPyBytesWriter_SetItem",
+    error_kind=ERR_NEVER,
+    experimental=True,
+    dependencies=[LIBRT_STRINGS, BYTES_WRITER_EXTRA_OPS],
+)
diff --git a/mypyc/test-data/irbuild-librt-strings.test b/mypyc/test-data/irbuild-librt-strings.test
index b61649844de2..c3f58febd33d 100644
--- a/mypyc/test-data/irbuild-librt-strings.test
+++ b/mypyc/test-data/irbuild-librt-strings.test
@@ -11,6 +11,10 @@ def bytes_writer_basics() -> bytes:
     return b.getvalue()
 def bytes_writer_len(b: BytesWriter) -> i64:
     return len(b)
+def bytes_writer_get_item(b: BytesWriter, i: i64) -> u8:
+    return b[i]
+def bytes_writer_set_item(b: BytesWriter, i: i64, x: u8) -> None:
+    b[i] = x
 [out]
 def bytes_writer_basics():
     r0, b :: librt.strings.BytesWriter
@@ -58,3 +62,34 @@ L0:
     r0 = CPyBytesWriter_Len(b)
     r1 = r0 >> 1
     return r1
+def bytes_writer_get_item(b, i):
+    b :: librt.strings.BytesWriter
+    i, r0 :: i64
+    r1, r2 :: bool
+    r3 :: u8
+L0:
+    r0 = CPyBytesWriter_AdjustIndex(b, i)
+    r1 = CPyBytesWriter_RangeCheck(b, r0)
+    if r1 goto L2 else goto L1 :: bool
+L1:
+    r2 = raise IndexError('index out of range')
+    unreachable
+L2:
+    r3 = CPyBytesWriter_GetItem(b, r0)
+    return r3
+def bytes_writer_set_item(b, i, x):
+    b :: librt.strings.BytesWriter
+    i :: i64
+    x :: u8
+    r0 :: i64
+    r1, r2 :: bool
+L0:
+    r0 = CPyBytesWriter_AdjustIndex(b, i)
+    r1 = CPyBytesWriter_RangeCheck(b, r0)
+    if r1 goto L2 else goto L1 :: bool
+L1:
+    r2 = raise IndexError('index out of range')
+    unreachable
+L2:
+    CPyBytesWriter_SetItem(b, r0, x)
+    return 1
diff --git a/mypyc/test-data/run-librt-strings.test b/mypyc/test-data/run-librt-strings.test
index 8a68f52e60c5..f3e0b7b13100 100644
--- a/mypyc/test-data/run-librt-strings.test
+++ b/mypyc/test-data/run-librt-strings.test
@@ -56,9 +56,9 @@ def test_bytes_writer_set_item() -> None:
     with assertRaises(IndexError):
         w[-(1 << 50)] = 0
 
-    with assertRaises(TypeError):
+    with assertRaises(ValueError):
         w[0] = int() - 1
-    with assertRaises(TypeError):
+    with assertRaises(ValueError):
         w[0] = int() + 256
 
     # Grow BytesWriter