Skip to content
Merged
1 change: 1 addition & 0 deletions mypyc/ir/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1210,6 +1210,7 @@ class RaiseStandardError(RegisterOp):
RUNTIME_ERROR: Final = "RuntimeError"
NAME_ERROR: Final = "NameError"
ZERO_DIVISION_ERROR: Final = "ZeroDivisionError"
INDEX_ERROR: Final = "IndexError"

def __init__(self, class_name: str, value: str | Value | None, line: int) -> None:
super().__init__(line)
Expand Down
2 changes: 2 additions & 0 deletions mypyc/ir/rtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,8 @@ def __hash__(self) -> int:
]
}

# Runtime type entry for librt.strings.BytesWriter, looked up by its fully
# qualified name in KNOWN_NATIVE_TYPES. Defined at module level so that other
# modules can import it instead of repeating the lookup.
bytes_writer_rprimitive: Final = KNOWN_NATIVE_TYPES["librt.strings.BytesWriter"]


def is_native_rprimitive(rtype: RType) -> bool:
    """Return True if rtype is an RPrimitive whose name is a key of KNOWN_NATIVE_TYPES."""
    if not isinstance(rtype, RPrimitive):
        return False
    return rtype.name in KNOWN_NATIVE_TYPES
Expand Down
7 changes: 7 additions & 0 deletions mypyc/irbuild/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@
tokenizer_printf_style,
)
from mypyc.irbuild.specialize import (
apply_dunder_specialization,
apply_function_specialization,
apply_method_specialization,
translate_object_new,
Expand Down Expand Up @@ -587,6 +588,12 @@ def transform_index_expr(builder: IRBuilder, expr: IndexExpr) -> Value:
is_list = is_list_rprimitive(base_type)
can_borrow_base = is_list and is_borrow_friendly_expr(builder, index)

# Check for dunder specialization for non-slice indexing
if not isinstance(index, SliceExpr):
specialized = apply_dunder_specialization(builder, expr.base, [index], "__getitem__", expr)
if specialized is not None:
return specialized

base = builder.accept(expr.base, can_borrow=can_borrow_base)

if isinstance(base.type, RTuple):
Expand Down
162 changes: 162 additions & 0 deletions mypyc/irbuild/specialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
RType,
bool_rprimitive,
bytes_rprimitive,
bytes_writer_rprimitive,
c_int_rprimitive,
dict_rprimitive,
int16_rprimitive,
Expand Down Expand Up @@ -104,6 +105,12 @@
from mypyc.primitives.float_ops import isinstance_float
from mypyc.primitives.generic_ops import generic_setattr, setup_object
from mypyc.primitives.int_ops import isinstance_int
from mypyc.primitives.librt_strings_ops import (
bytes_writer_adjust_index_op,
bytes_writer_get_item_unsafe_op,
bytes_writer_range_check_op,
bytes_writer_set_item_unsafe_op,
)
from mypyc.primitives.list_ops import isinstance_list, new_list_set_item_op
from mypyc.primitives.misc_ops import isinstance_bool
from mypyc.primitives.set_ops import isinstance_frozenset, isinstance_set
Expand All @@ -127,12 +134,25 @@
# compiled, and the RefExpr that is the left hand side of the call.
Specializer = Callable[["IRBuilder", CallExpr, RefExpr], Value | None]

# Dunder specializers handle special method calls such as __getitem__ and
# __setitem__ that don't naturally map to CallExpr nodes (e.g., they come from
# an IndexExpr).
#
# They take four arguments: the IRBuilder, the base expression (target object),
# the list of argument expressions (positional arguments to the dunder), and the
# context expression (e.g., IndexExpr) used for line information and error
# reporting. A specializer returns the resulting Value, or None to indicate it
# did not handle the operation so the caller should fall back to the generic path.
DunderSpecializer = Callable[["IRBuilder", Expression, list[Expression], Expression], Value | None]

# Dictionary containing all configured specializers.
#
# Specializers can operate on methods as well, and are keyed on the
# name and RType in that case.
specializers: dict[tuple[str, RType | None], list[Specializer]] = {}

# Dictionary containing all configured dunder specializers.
#
# Dunder specializers are keyed on the dunder name and the RType of the target
# object (a dunder is always a method call, so the type is never None). Each
# value is a list in registration order; when applied, the specializers are
# tried in that order and the first one that returns a non-None value wins.
dunder_specializers: dict[tuple[str, RType], list[DunderSpecializer]] = {}


def _apply_specialization(
builder: IRBuilder, expr: CallExpr, callee: RefExpr, name: str | None, typ: RType | None = None
Expand Down Expand Up @@ -182,6 +202,53 @@ def wrapper(f: Specializer) -> Specializer:
return wrapper


def specialize_dunder(name: str, typ: RType) -> Callable[[DunderSpecializer], DunderSpecializer]:
    """Return a decorator that registers a dunder specializer for (name, typ).

    Dunder specializers cover special method calls such as __getitem__ that
    are not represented as CallExpr nodes.

    Several specializers may be registered for the same dunder and type. They
    are kept in registration order, and the one registered first takes
    priority when a dunder call is translated.
    """

    def register(func: DunderSpecializer) -> DunderSpecializer:
        key = (name, typ)
        # Create the per-key list on first registration, then append in order.
        if key not in dunder_specializers:
            dunder_specializers[key] = []
        dunder_specializers[key].append(func)
        # Hand the function back unchanged so it stays usable under its own name.
        return func

    return register


def apply_dunder_specialization(
    builder: IRBuilder,
    base_expr: Expression,
    args: list[Expression],
    name: str,
    ctx_expr: Expression,
) -> Value | None:
    """Try the dunder specializers registered for this dunder and base type.

    Args:
        builder: The IR builder
        base_expr: The base expression (target object)
        args: List of argument expressions (positional arguments to the dunder)
        name: The dunder method name (e.g., "__getitem__")
        ctx_expr: The context expression for error reporting (e.g., IndexExpr)

    Returns:
        The value produced by the first registered specializer that returns a
        non-None result, or None if nothing is registered for this dunder and
        type or every registered specializer declined.
    """
    receiver_type = builder.node_type(base_expr)

    # An absent key behaves like an empty list of candidates.
    candidates = dunder_specializers.get((name, receiver_type), [])
    for candidate in candidates:
        result = candidate(builder, base_expr, args, ctx_expr)
        if result is not None:
            return result
    return None


@specialize_function("builtins.globals")
def translate_globals(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None:
if len(expr.args) == 0:
Expand Down Expand Up @@ -1137,3 +1204,98 @@ def translate_object_setattr(builder: IRBuilder, expr: CallExpr, callee: RefExpr

name_reg = builder.accept(attr_name)
return builder.call_c(generic_setattr, [self_reg, name_reg, value], expr.line)


def _emit_bytes_writer_index_check(
    builder: IRBuilder, obj: Value, index: Value, line: int
) -> Value:
    """Emit IR that normalizes and bounds-checks an index into a BytesWriter.

    Shared by the BytesWriter __getitem__ and __setitem__ specializers so the
    adjust / range-check / raise sequence is written only once. The operands
    are passed in already evaluated, so callers keep full control over the
    order in which their expressions are evaluated.

    Args:
        builder: The IR builder
        obj: The already-evaluated BytesWriter value
        index: The already-evaluated index value
        line: Line number attached to the emitted ops

    Returns:
        The adjusted index value. On return, the active block is the one
        reached only when the range check succeeded.
    """
    # Adjust the index (handle negative indices); no range checking yet.
    adjusted_index = builder.primitive_op(bytes_writer_adjust_index_op, [obj, index], line)

    # Check if the adjusted index is in valid range.
    range_check = builder.primitive_op(bytes_writer_range_check_op, [obj, adjusted_index], line)

    valid_block = BasicBlock()
    invalid_block = BasicBlock()
    builder.add_bool_branch(range_check, valid_block, invalid_block)

    # Invalid index: raise IndexError. Nothing may follow in this block.
    builder.activate_block(invalid_block)
    builder.add(RaiseStandardError(RaiseStandardError.INDEX_ERROR, "index out of range", line))
    builder.add(Unreachable())

    # Valid index: leave this block active for the caller's unchecked access.
    builder.activate_block(valid_block)
    return adjusted_index


@specialize_dunder("__getitem__", bytes_writer_rprimitive)
def translate_bytes_writer_get_item(
    builder: IRBuilder, base_expr: Expression, args: list[Expression], ctx_expr: Expression
) -> Value | None:
    """Optimized BytesWriter.__getitem__ implementation with bounds checking.

    Args:
        builder: The IR builder
        base_expr: Expression producing the BytesWriter
        args: Positional arguments to __getitem__; exactly one (the index)
        ctx_expr: Expression whose line is attached to the emitted ops

    Returns:
        The value read from the writer, or None (without emitting any IR) if
        the argument count is not exactly one.
    """
    if len(args) != 1:
        return None

    # Evaluate the BytesWriter first, then the index.
    obj = builder.accept(base_expr)
    index = builder.accept(args[0])

    adjusted_index = _emit_bytes_writer_index_check(builder, obj, index, ctx_expr.line)

    # The index is known to be in range here, so the unchecked read is safe.
    return builder.primitive_op(
        bytes_writer_get_item_unsafe_op, [obj, adjusted_index], ctx_expr.line
    )


@specialize_dunder("__setitem__", bytes_writer_rprimitive)
def translate_bytes_writer_set_item(
    builder: IRBuilder, base_expr: Expression, args: list[Expression], ctx_expr: Expression
) -> Value | None:
    """Optimized BytesWriter.__setitem__ implementation with bounds checking.

    Args:
        builder: The IR builder
        base_expr: Expression producing the BytesWriter
        args: Positional arguments to __setitem__; exactly two (index, value)
        ctx_expr: Expression whose line is attached to the emitted ops

    Returns:
        A None value on success, or Python None (without emitting any IR) if
        the argument count is not exactly two.
    """
    if len(args) != 2:
        return None

    # Evaluate all operands before the bounds check: BytesWriter, index, value.
    # NOTE(review): a plain Python `a[i] = v` evaluates `v` before `a` and `i`,
    # and this function evaluates the value expression itself -- confirm that
    # callers have not already evaluated the same right-hand side expression.
    obj = builder.accept(base_expr)
    index = builder.accept(args[0])
    value = builder.accept(args[1])

    adjusted_index = _emit_bytes_writer_index_check(builder, obj, index, ctx_expr.line)

    # The index is known to be in range here, so the unchecked write is safe.
    builder.primitive_op(
        bytes_writer_set_item_unsafe_op, [obj, adjusted_index, value], ctx_expr.line
    )

    return builder.none()
11 changes: 11 additions & 0 deletions mypyc/irbuild/statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
Import,
ImportAll,
ImportFrom,
IndexExpr,
ListExpr,
Lvalue,
MatchStmt,
Expand Down Expand Up @@ -92,6 +93,7 @@
TryFinallyNonlocalControl,
)
from mypyc.irbuild.prepare import GENERATOR_HELPER_NAME
from mypyc.irbuild.specialize import apply_dunder_specialization
from mypyc.irbuild.targets import (
AssignmentTarget,
AssignmentTargetAttr,
Expand Down Expand Up @@ -260,6 +262,15 @@ def transform_assignment_stmt(builder: IRBuilder, stmt: AssignmentStmt) -> None:
return

for lvalue in lvalues:
# Check for __setitem__ dunder specialization before converting to assignment target
if isinstance(lvalue, IndexExpr):
specialized = apply_dunder_specialization(
builder, lvalue.base, [lvalue.index, stmt.rvalue], "__setitem__", lvalue
)
if specialized is not None:
builder.flush_keep_alives()
continue

target = builder.get_assignment_target(lvalue)
builder.assign(target, rvalue_reg, line)
builder.flush_keep_alives()
Expand Down
20 changes: 20 additions & 0 deletions mypyc/lib-rt/byteswriter_extra_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,26 @@ CPyBytesWriter_Append(PyObject *obj, uint8_t value) {

char CPyBytesWriter_Write(PyObject *obj, PyObject *value);

// Convert a negative index to a non-negative one by adding the writer's length.
// A non-negative index is returned unchanged. No range checking is performed,
// so the result can still be out of bounds (including still negative).
static inline int64_t CPyBytesWriter_AdjustIndex(PyObject *obj, int64_t index) {
    BytesWriterObject *writer = (BytesWriterObject *)obj;
    return index < 0 ? index + writer->len : index;
}

// Return true if index lies in [0, len) for the given writer, false otherwise.
static inline bool CPyBytesWriter_RangeCheck(PyObject *obj, int64_t index) {
    BytesWriterObject *writer = (BytesWriterObject *)obj;
    if (index < 0) {
        return false;
    }
    return index < writer->len;
}

// Read the byte at position index of the writer's buffer.
// No bounds checking is done here.
static inline uint8_t CPyBytesWriter_GetItem(PyObject *obj, int64_t index) {
    BytesWriterObject *writer = (BytesWriterObject *)obj;
    return writer->buf[index];
}

// Store byte x at position index of the writer's buffer.
// No bounds checking is done here.
static inline void CPyBytesWriter_SetItem(PyObject *obj, int64_t index, uint8_t x) {
    BytesWriterObject *writer = (BytesWriterObject *)obj;
    writer->buf[index] = x;
}

#endif // MYPYC_EXPERIMENTAL

#endif
54 changes: 48 additions & 6 deletions mypyc/primitives/librt_strings_ops.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
from typing import Final

from mypyc.ir.deps import BYTES_WRITER_EXTRA_OPS, LIBRT_STRINGS
from mypyc.ir.ops import ERR_MAGIC, ERR_NEVER
from mypyc.ir.rtypes import (
KNOWN_NATIVE_TYPES,
bool_rprimitive,
bytes_rprimitive,
bytes_writer_rprimitive,
int64_rprimitive,
none_rprimitive,
short_int_rprimitive,
uint8_rprimitive,
void_rtype,
)
from mypyc.primitives.registry import function_op, method_op

bytes_writer_rprimitive: Final = KNOWN_NATIVE_TYPES["librt.strings.BytesWriter"]
from mypyc.primitives.registry import custom_primitive_op, function_op, method_op

function_op(
name="librt.strings.BytesWriter",
Expand Down Expand Up @@ -73,3 +71,47 @@
experimental=True,
dependencies=[LIBRT_STRINGS, BYTES_WRITER_EXTRA_OPS],
)

# BytesWriter index adjustment - convert negative index to non-negative.
#
# Wraps the C helper CPyBytesWriter_AdjustIndex: takes a BytesWriter and an
# int64 index and returns an int64. The helper adds the writer's length to a
# negative index and returns a non-negative index unchanged. It performs no
# range checking, so the result must still be validated before it is used to
# access the buffer. Declared as never raising (ERR_NEVER).
bytes_writer_adjust_index_op = custom_primitive_op(
name="bytes_writer_adjust_index",
arg_types=[bytes_writer_rprimitive, int64_rprimitive],
return_type=int64_rprimitive,
c_function_name="CPyBytesWriter_AdjustIndex",
error_kind=ERR_NEVER,
experimental=True,
dependencies=[LIBRT_STRINGS, BYTES_WRITER_EXTRA_OPS],
)

# BytesWriter range check - check if index is in valid range.
#
# Wraps the C helper CPyBytesWriter_RangeCheck: takes a BytesWriter and an
# int64 index and returns a bool that is true when 0 <= index < length of the
# writer. Negative indexes are reported as out of range, so the index should
# already have been adjusted. Declared as never raising (ERR_NEVER).
bytes_writer_range_check_op = custom_primitive_op(
name="bytes_writer_range_check",
arg_types=[bytes_writer_rprimitive, int64_rprimitive],
return_type=bool_rprimitive,
c_function_name="CPyBytesWriter_RangeCheck",
error_kind=ERR_NEVER,
experimental=True,
dependencies=[LIBRT_STRINGS, BYTES_WRITER_EXTRA_OPS],
)

# BytesWriter.__getitem__() - get byte at index (no bounds checking).
#
# Wraps the C helper CPyBytesWriter_GetItem: takes a BytesWriter and an int64
# index and returns the uint8 stored at that position of the writer's buffer.
# "Unsafe" because the helper reads the buffer without validating the index;
# the index must have been adjusted and range-checked beforehand. Declared as
# never raising (ERR_NEVER).
bytes_writer_get_item_unsafe_op = custom_primitive_op(
name="bytes_writer_get_item",
arg_types=[bytes_writer_rprimitive, int64_rprimitive],
return_type=uint8_rprimitive,
c_function_name="CPyBytesWriter_GetItem",
error_kind=ERR_NEVER,
experimental=True,
dependencies=[LIBRT_STRINGS, BYTES_WRITER_EXTRA_OPS],
)

# BytesWriter.__setitem__() - set byte at index (no bounds checking).
#
# Wraps the C helper CPyBytesWriter_SetItem: takes a BytesWriter, an int64
# index and a uint8 value, stores the value at that position of the writer's
# buffer and returns nothing (void). "Unsafe" because the helper writes to the
# buffer without validating the index; the index must have been adjusted and
# range-checked beforehand. Declared as never raising (ERR_NEVER).
bytes_writer_set_item_unsafe_op = custom_primitive_op(
name="bytes_writer_set_item",
arg_types=[bytes_writer_rprimitive, int64_rprimitive, uint8_rprimitive],
return_type=void_rtype,
c_function_name="CPyBytesWriter_SetItem",
error_kind=ERR_NEVER,
experimental=True,
dependencies=[LIBRT_STRINGS, BYTES_WRITER_EXTRA_OPS],
)
Loading