Skip to content

[mypyc] Add SetElement op for initializing struct values #19437

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion mypyc/analysis/dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,14 @@
RegisterOp,
Return,
SetAttr,
SetElement,
SetMem,
Truncate,
TupleGet,
TupleSet,
Unborrow,
Unbox,
Undef,
Unreachable,
Value,
)
Expand Down Expand Up @@ -272,6 +274,9 @@ def visit_load_mem(self, op: LoadMem) -> GenAndKill[T]:
def visit_get_element_ptr(self, op: GetElementPtr) -> GenAndKill[T]:
return self.visit_register_op(op)

def visit_set_element(self, op: SetElement) -> GenAndKill[T]:
    """Handle SetElement like any other register op.

    Delegates to visit_register_op and returns its result unchanged.
    """
    return self.visit_register_op(op)

def visit_load_address(self, op: LoadAddress) -> GenAndKill[T]:
return self.visit_register_op(op)

Expand Down Expand Up @@ -444,7 +449,7 @@ def visit_set_mem(self, op: SetMem) -> GenAndKill[Value]:
def non_trivial_sources(op: Op) -> set[Value]:
    """Return the sources of op that are not Integer, Float or Undef values.

    Args:
        op: The op whose sources() are inspected.

    Returns:
        A set with every source of op except those that are instances of
        Integer, Float or Undef.
    """
    return {
        source for source in op.sources() if not isinstance(source, (Integer, Float, Undef))
    }

Expand Down
8 changes: 7 additions & 1 deletion mypyc/analysis/ircheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
ControlOp,
DecRef,
Extend,
Float,
FloatComparisonOp,
FloatNeg,
FloatOp,
Expand All @@ -42,12 +43,14 @@
Register,
Return,
SetAttr,
SetElement,
SetMem,
Truncate,
TupleGet,
TupleSet,
Unborrow,
Unbox,
Undef,
Unreachable,
Value,
)
Expand Down Expand Up @@ -148,7 +151,7 @@ def check_op_sources_valid(fn: FuncIR) -> list[FnError]:
for block in fn.blocks:
for op in block.ops:
for source in op.sources():
if isinstance(source, Integer):
if isinstance(source, (Integer, Float, Undef)):
pass
elif isinstance(source, Op):
if source not in valid_ops:
Expand Down Expand Up @@ -423,6 +426,9 @@ def visit_set_mem(self, op: SetMem) -> None:
def visit_get_element_ptr(self, op: GetElementPtr) -> None:
pass

def visit_set_element(self, op: SetElement) -> None:
    """Perform no checks for SetElement ops (intentionally a no-op here)."""
    pass

def visit_load_address(self, op: LoadAddress) -> None:
pass

Expand Down
4 changes: 4 additions & 0 deletions mypyc/analysis/selfleaks.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
RegisterOp,
Return,
SetAttr,
SetElement,
SetMem,
Truncate,
TupleGet,
Expand Down Expand Up @@ -181,6 +182,9 @@ def visit_load_mem(self, op: LoadMem) -> GenAndKill:
def visit_get_element_ptr(self, op: GetElementPtr) -> GenAndKill:
return CLEAN

def visit_set_element(self, op: SetElement) -> GenAndKill:
    """Return CLEAN for SetElement ops, regardless of the op's operands."""
    # NOTE(review): CLEAN presumably means "self does not leak here";
    # confirm against its definition in this module.
    return CLEAN

def visit_load_address(self, op: LoadAddress) -> GenAndKill:
return CLEAN

Expand Down
27 changes: 27 additions & 0 deletions mypyc/codegen/emitfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,14 @@
Register,
Return,
SetAttr,
SetElement,
SetMem,
Truncate,
TupleGet,
TupleSet,
Unborrow,
Unbox,
Undef,
Unreachable,
Value,
)
Expand Down Expand Up @@ -813,6 +815,31 @@ def visit_get_element_ptr(self, op: GetElementPtr) -> None:
)
)

def visit_set_element(self, op: SetElement) -> None:
    """Emit C code for a SetElement op.

    When the source is an Undef value, a single field assignment on the
    destination is emitted. Otherwise the destination is assigned a
    compound literal that lists every field of the source struct type,
    with the field named by op.field replaced by the new item.
    """
    target = self.reg(op)
    new_value = self.reg(op.item)
    field_name = op.field

    if isinstance(op.src, Undef):
        # First assignment to an undefined struct is trivial.
        self.emit_line(f"{target}.{field_name} = {new_value};")
        return

    # In the general case create a copy of the struct with a single
    # item modified.
    #
    # TODO: Can we do better if only a subset of fields are initialized?
    # TODO: Make this less verbose in the common case
    # TODO: Support tuples (or use RStruct for tuples)?
    base = self.reg(op.src)
    struct_type = op.src.type
    assert isinstance(struct_type, RStruct), struct_type
    # Keep every field from the source except the one being replaced.
    initializers = [
        new_value if name == field_name else f"{base}.{name}" for name in struct_type.names
    ]
    self.emit_line(
        f"{target} = ({self.ctype(struct_type)}) {{ {', '.join(initializers)} }};"
    )

def visit_load_address(self, op: LoadAddress) -> None:
typ = op.type
dest = self.reg(op)
Expand Down
58 changes: 58 additions & 0 deletions mypyc/ir/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class to enable the new behavior. Sometimes adding a new abstract
from mypyc.ir.rtypes import (
RArray,
RInstance,
RStruct,
RTuple,
RType,
RVoid,
Expand Down Expand Up @@ -244,6 +245,26 @@ def __init__(self, value: bytes, line: int = -1) -> None:
self.line = line


@final
class Undef(Value):
    """A placeholder for a value whose contents are not yet defined.

    The intended use is as the starting point for building a struct:
    begin with Undef() and apply one or more SetElement ops until every
    field has been set. Pseudocode example:

      r0 = set_element undef MyStruct, "field1", f1
      r1 = set_element r0, "field2", f2
      # r1 now has new struct value with two fields set

    Warning: An undefined value holds garbage until it is initialized.
    Do not rely on it being zeroed, and never use it before every part
    that will be read has been set.
    """

    def __init__(self, rtype: RType) -> None:
        # Only the type is recorded; there is no payload for an undefined value.
        self.type = rtype


class Op(Value):
"""Abstract base class for all IR operations.

Expand Down Expand Up @@ -1636,6 +1657,39 @@ def accept(self, visitor: OpVisitor[T]) -> T:
return visitor.visit_get_element_ptr(self)


@final
class SetElement(RegisterOp):
"""Set the value of a struct element.

This evaluates to a new struct with the changed value.

Use together with Undef to initialize a fresh struct value
(see Undef for more details).
"""

error_kind = ERR_NEVER

def __init__(self, src: Value, field: str, item: Value, line: int = -1) -> None:
super().__init__(line)
assert isinstance(src.type, RStruct), src.type
self.type = src.type
Copy link

@peturingi peturingi Jul 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Derivable from src; make it a `@property`?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Irrelevant as it won't play nice with

def visit_set_element(self, op: SetElement) -> None:
        op.src = self.fix_op(op.src)

self.src = src
self.item = item
self.field = field

def sources(self) -> list[Value]:
    """Return a one-element list holding the source struct value.

    NOTE(review): self.item is also an operand of this op but is not
    reported here -- confirm that leaving it out is intentional.
    """
    return [self.src]

def set_sources(self, new: list[Value]) -> None:
    """Replace the source struct value.

    `new` must contain exactly one value; the single-element unpacking
    raises ValueError for any other length.
    """
    (self.src,) = new

def stolen(self) -> list[Value]:
    """Return a one-element list holding the source struct value.

    NOTE(review): presumably this marks the values whose references the
    op takes over -- confirm against the base class contract.
    """
    return [self.src]

def accept(self, visitor: OpVisitor[T]) -> T:
    """Dispatch to visitor.visit_set_element and return its result."""
    return visitor.visit_set_element(self)


@final
class LoadAddress(RegisterOp):
"""Get the address of a value: result = (type)&src
Expand Down Expand Up @@ -1908,6 +1962,10 @@ def visit_set_mem(self, op: SetMem) -> T:
def visit_get_element_ptr(self, op: GetElementPtr) -> T:
raise NotImplementedError

@abstractmethod
def visit_set_element(self, op: SetElement) -> T:
    """Visit a SetElement op; concrete visitors must override this."""
    raise NotImplementedError

@abstractmethod
def visit_load_address(self, op: LoadAddress) -> T:
raise NotImplementedError
Expand Down
9 changes: 8 additions & 1 deletion mypyc/ir/pprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,14 @@
Register,
Return,
SetAttr,
SetElement,
SetMem,
Truncate,
TupleGet,
TupleSet,
Unborrow,
Unbox,
Undef,
Unreachable,
Value,
)
Expand Down Expand Up @@ -273,6 +275,9 @@ def visit_set_mem(self, op: SetMem) -> str:
def visit_get_element_ptr(self, op: GetElementPtr) -> str:
return self.format("%r = get_element_ptr %r %s :: %t", op, op.src, op.field, op.src_type)

def visit_set_element(self, op: SetElement) -> str:
    """Format a SetElement op as "<op> = set_element <src>, <field>, <item>"."""
    # The op, its source and its item go through the %r spec; the field name uses %s.
    return self.format("%r = set_element %r, %s, %r", op, op.src, op.field, op.item)

def visit_load_address(self, op: LoadAddress) -> str:
if isinstance(op.src, Register):
return self.format("%r = load_address %r", op, op.src)
Expand Down Expand Up @@ -330,6 +335,8 @@ def format(self, fmt: str, *args: Any) -> str:
result.append(repr(arg.value))
elif isinstance(arg, CString):
result.append(f"CString({arg.value!r})")
elif isinstance(arg, Undef):
result.append(f"undef {arg.type.name}")
else:
result.append(self.names[arg])
elif typespec == "d":
Expand Down Expand Up @@ -486,7 +493,7 @@ def generate_names_for_ir(args: list[Register], blocks: list[BasicBlock]) -> dic
continue
if isinstance(value, Register) and value.name:
name = value.name
elif isinstance(value, (Integer, Float)):
elif isinstance(value, (Integer, Float, Undef)):
continue
else:
name = "r%d" % temp_index
Expand Down
18 changes: 18 additions & 0 deletions mypyc/test/test_emitfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,11 @@
Register,
Return,
SetAttr,
SetElement,
SetMem,
TupleGet,
Unbox,
Undef,
Unreachable,
Value,
)
Expand Down Expand Up @@ -121,6 +123,11 @@ def add_local(name: str, rtype: RType) -> Register:
self.r = add_local("r", RInstance(ir))
self.none = add_local("none", none_rprimitive)

self.struct_type = RStruct(
"Foo", ["b", "x", "y"], [bool_rprimitive, int32_rprimitive, int64_rprimitive]
)
self.st = add_local("st", self.struct_type)

self.context = EmitterContext(NameGenerator([["mod"]]))

def test_goto(self) -> None:
Expand Down Expand Up @@ -674,6 +681,17 @@ def test_get_element_ptr(self) -> None:
GetElementPtr(self.o, r, "i64"), """cpy_r_r0 = (CPyPtr)&((Foo *)cpy_r_o)->i64;"""
)

def test_set_element(self) -> None:
    """Check the C emitted for SetElement with an Undef and a defined source."""
    # Use compact syntax when setting the initial element of an undefined value
    initial_set = SetElement(Undef(self.struct_type), "b", self.b)
    self.assert_emit(initial_set, """cpy_r_r0.b = cpy_r_b;""")

    # We propagate the unchanged values in subsequent assignments
    later_set = SetElement(self.st, "x", self.i32)
    expected_copy = """cpy_r_r0 = (Foo) { cpy_r_st.b, cpy_r_i32, cpy_r_st.y };"""
    self.assert_emit(later_set, expected_copy)

def test_load_address(self) -> None:
self.assert_emit(
LoadAddress(object_rprimitive, "PyDict_Type"),
Expand Down
7 changes: 7 additions & 0 deletions mypyc/transform/ir_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
RaiseStandardError,
Return,
SetAttr,
SetElement,
SetMem,
Truncate,
TupleGet,
Expand Down Expand Up @@ -214,6 +215,9 @@ def visit_set_mem(self, op: SetMem) -> Value | None:
def visit_get_element_ptr(self, op: GetElementPtr) -> Value | None:
return self.add(op)

def visit_set_element(self, op: SetElement) -> Value | None:
    """Pass the SetElement op to self.add and return whatever it returns."""
    return self.add(op)

def visit_load_address(self, op: LoadAddress) -> Value | None:
return self.add(op)

Expand Down Expand Up @@ -354,6 +358,9 @@ def visit_set_mem(self, op: SetMem) -> None:
def visit_get_element_ptr(self, op: GetElementPtr) -> None:
op.src = self.fix_op(op.src)

def visit_set_element(self, op: SetElement) -> None:
    """Patch both operand references of a SetElement op.

    A SetElement op refers to two values: the source struct (op.src) and
    the new field value (op.item). Both are passed through self.fix_op
    so that neither keeps pointing at a value that has been replaced.
    """
    op.src = self.fix_op(op.src)
    # The item is an operand as well; leaving it unpatched could keep a
    # stale reference after the value it names has been replaced.
    op.item = self.fix_op(op.item)

def visit_load_address(self, op: LoadAddress) -> None:
if isinstance(op.src, LoadStatic):
new = self.fix_op(op.src)
Expand Down
3 changes: 2 additions & 1 deletion mypyc/transform/refcount.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
Op,
Register,
RegisterOp,
Undef,
Value,
)

Expand Down Expand Up @@ -94,7 +95,7 @@ def is_maybe_undefined(post_must_defined: set[Value], src: Value) -> bool:
def maybe_append_dec_ref(
    ops: list[Op], dest: Value, defined: AnalysisDict[Value], key: tuple[BasicBlock, int]
) -> None:
    """Append a DecRef of dest to ops when dest needs one.

    Nothing is appended if dest's type is not refcounted, or if dest is
    an Integer or Undef value. Otherwise the DecRef's is_xdec flag is set
    from is_maybe_undefined(defined[key], dest).

    Args:
        ops: List that receives the new DecRef op (mutated in place).
        dest: The value to decrement.
        defined: Analysis result indexed by `key`.
        key: The (block, op index) position used to look up `defined`.
    """
    if not dest.type.is_refcounted or isinstance(dest, (Integer, Undef)):
        return
    ops.append(DecRef(dest, is_xdec=is_maybe_undefined(defined[key], dest)))


Expand Down