diff --git a/.gitignore b/.gitignore index 0d13d2dd3..55399e692 100644 --- a/.gitignore +++ b/.gitignore @@ -158,3 +158,5 @@ devenv.local.nix **/.benchmarks/ .vscode/settings.json guppy-exports/ + +AGENTS.md diff --git a/guppylang-internals/src/guppylang_internals/checker/cfg_checker.py b/guppylang-internals/src/guppylang_internals/checker/cfg_checker.py index 39603c6d3..0139e91fd 100644 --- a/guppylang-internals/src/guppylang_internals/checker/cfg_checker.py +++ b/guppylang-internals/src/guppylang_internals/checker/cfg_checker.py @@ -91,6 +91,7 @@ def check_cfg( `first_modifier_node`: if None, the cfg is not a modifier block. Otherwise, it's the AST node of the first modifier, used in error reporting. """ + # First, we need to run program analysis ass_before = {v.name for v in inputs} inout_vars = [v for v in inputs if InputFlags.Inout in v.flags] diff --git a/guppylang-internals/src/guppylang_internals/checker/errors/linearity.py b/guppylang-internals/src/guppylang_internals/checker/errors/linearity.py index 6d47fc42d..1142bd417 100644 --- a/guppylang-internals/src/guppylang_internals/checker/errors/linearity.py +++ b/guppylang-internals/src/guppylang_internals/checker/errors/linearity.py @@ -55,6 +55,26 @@ class MakeCopy(Help): ) +@dataclass(frozen=True) +class ModifiedVariableUsedError(Error): + title: ClassVar[str] = "Variable modified in modifier block" + span_label: ClassVar[str] = ( + "Cannot use `{place}` because it was modified inside a modifier block" + ) + place: Place + + @dataclass(frozen=True) + class ModifiedHere(Note): + span_label: ClassVar[str] = "`{place}` modified here" + place: Place + + @dataclass(frozen=True) + class Explanation(Help): + message: ClassVar[str] = ( + "Modifications inside modifier blocks are not reflected outside the block" + ) + + @dataclass(frozen=True) class ComprAlreadyUsedError(Error): title: ClassVar[str] = "Copy violation" diff --git a/guppylang-internals/src/guppylang_internals/checker/linearity_checker.py b/guppylang-internals/src/guppylang_internals/checker/linearity_checker.py index fe5794069..437769799 100644 --- a/guppylang-internals/src/guppylang_internals/checker/linearity_checker.py +++ b/guppylang-internals/src/guppylang_internals/checker/linearity_checker.py @@ -36,6 +36,7 @@ ComprAlreadyUsedError, DropAfterCallError, InCallArg, + ModifiedVariableUsedError, MoveOutOfSubscriptError, NonCopyableCaptureError, NonCopyablePartialApplyError, @@ -101,6 +102,9 @@ class UseKind(Enum): #: An owned value is renamed or stored in a tuple/list MOVE = auto() + #: A captured value is assigned inside a modifier block + DEFINED_IN_MODIFIER = auto() + @property def indicative(self) -> str: """Describes a use in an indicative mood. @@ -126,6 +130,8 @@ def subjunctive(self) -> str: return "returned" case UseKind.MOVE: return "moved" + case UseKind.DEFINED_IN_MODIFIER: + return "modified" class Use(NamedTuple): @@ -298,9 +304,11 @@ def visit_PlaceNode( self.visit(subscript.getitem_call) # For all other places, we record uses of all leaves else: + # Check each leaf separately so we catch partial moves, e.g. struct fields for place in leaf_places(node.place): x = place.id - if (prev_use := self.scope.used(x)) and not place.ty.copyable: + prev_use = self.scope.used(x) + if prev_use and not place.ty.copyable: # When the user's expression (node.place) differs from the # conflicting leaf (place), report the error about the parent # and explain which child was already moved. @@ -319,6 +327,15 @@ def visit_PlaceNode( if has_explicit_copy(place.ty): err.add_sub_diagnostic(AlreadyUsedError.MakeCopy(None)) raise GuppyError(err) + # A modifier-block assignment makes the original binding stale, so + # any subsequent use in the same scope is rejected. + if prev_use and prev_use.kind == UseKind.DEFINED_IN_MODIFIER: + err = ModifiedVariableUsedError(node, place) + err.add_sub_diagnostic( + ModifiedVariableUsedError.ModifiedHere(prev_use.node, place) + ) + err.add_sub_diagnostic(ModifiedVariableUsedError.Explanation(None)) + raise GuppyError(err) self.scope.use(x, node, use_kind) def visit_Assign(self, node: ast.Assign) -> None: @@ -685,7 +702,7 @@ def visit_CheckedModifiedBlock(self, node: CheckedModifiedBlock) -> None: # body(q1, q2, ...) # ``` - # check control + # Check control for ctrl in node.control: for arg in ctrl.ctrl: if isinstance(arg, PlaceNode): @@ -696,7 +713,7 @@ def visit_CheckedModifiedBlock(self, node: CheckedModifiedBlock) -> None: unnamed_err.add_sub_diagnostic(UnnamedExprNotUsedError.Fix(None)) raise GuppyTypeError(unnamed_err) - # check power + # Check power for power in node.power: if isinstance(power.iter, PlaceNode): self.visit_PlaceNode( @@ -705,22 +722,34 @@ def visit_CheckedModifiedBlock(self, node: CheckedModifiedBlock) -> None: else: self.visit(power.iter) - # check captured variables + # Check captured variables: + # We check that the modifier is not using consumed variables or copyable + # variables already used by other modifiers for var, use in node.captured.values(): for place in leaf_places(var): use_kind = ( UseKind.BORROW if InputFlags.Inout in var.flags else UseKind.CONSUME ) - x = place.id - if (prev_use := self.scope.used(x)) and not place.ty.copyable: - used_err = AlreadyUsedError(use, place, use_kind) - used_err.add_sub_diagnostic( - AlreadyUsedError.PrevUse(prev_use.node, prev_use.kind) - ) - if has_explicit_copy(place.ty): - used_err.add_sub_diagnostic(AlreadyUsedError.MakeCopy(None)) - raise GuppyError(used_err) + if prev_use := self.scope.used(x): + if not place.ty.copyable: + used_err = AlreadyUsedError(use, place, use_kind) + used_err.add_sub_diagnostic( + AlreadyUsedError.PrevUse(prev_use.node, prev_use.kind) + ) + if has_explicit_copy(place.ty): + used_err.add_sub_diagnostic(AlreadyUsedError.MakeCopy(None)) + raise GuppyError(used_err) + if prev_use.kind == UseKind.DEFINED_IN_MODIFIER: + err = ModifiedVariableUsedError(use, place) + err.add_sub_diagnostic( + ModifiedVariableUsedError.ModifiedHere(prev_use.node, place) + ) + err.add_sub_diagnostic( + ModifiedVariableUsedError.Explanation(None) + ) + raise GuppyError(err) + self.scope.use(x, node, use_kind) # reassign controls @@ -734,6 +763,12 @@ def visit_CheckedModifiedBlock(self, node: CheckedModifiedBlock) -> None: if InputFlags.Inout in var.flags: self._reassign_single_inout_arg(var, var.defined_at or use) + for name, assignment in node.modified_captured.items(): + var, _ = node.captured[name] + if var.ty.copyable: + for place in leaf_places(var): + self.scope.use(place.id, assignment, UseKind.DEFINED_IN_MODIFIER) + def leaf_places(place: Place) -> Iterator[Place]: """Returns all leaf descendant projections of a place.""" @@ -889,36 +924,51 @@ def check_cfg_linearity( for bb, scope in scopes.items(): live_before_bb = live_before[bb] - # We have to check that used not copyable variables are not being outputted + # Check that values made unusable in this block are not live in successors. + # This catches both non-copyable values that were already used and copyable + # captures that were re-assigned inside a modifier block. for succ in bb.successors: live = live_before[succ] for x, use_bb in live.items(): use_scope = scopes[use_bb] place = use_scope[x] - if not place.ty.copyable and (prev_use := scope.used(x)): - use = use_scope.used_parent[x] - # Special case if this is a use arising from the implicit returning - # of a borrowed argument - if isinstance(use.node, InoutReturnSentinel): - assert isinstance(use.node.var, Variable) - assert InputFlags.Inout in use.node.var.flags - err: Error = BorrowSubPlaceUsedError( - use.node.var.defined_at, use.node.var, place + if prev_use := scope.used(x): + # first we check for variable non-copyable variable + if not place.ty.copyable: + use = use_scope.used_parent[x] + # Special case if this is a use arising from the implicit + # returning of a borrowed argument + if isinstance(use.node, InoutReturnSentinel): + assert isinstance(use.node.var, Variable) + assert InputFlags.Inout in use.node.var.flags + err: Error = BorrowSubPlaceUsedError( + use.node.var.defined_at, use.node.var, place + ) + err.add_sub_diagnostic( + BorrowSubPlaceUsedError.PrevUse( + prev_use.node, prev_use.kind + ) + ) + err.add_sub_diagnostic(BorrowSubPlaceUsedError.Fix(None)) + raise GuppyError(err) + err = AlreadyUsedError(use.node, place, use.kind) + err.add_sub_diagnostic( + AlreadyUsedError.PrevUse(prev_use.node, prev_use.kind) ) + if has_explicit_copy(place.ty): + err.add_sub_diagnostic(AlreadyUsedError.MakeCopy(None)) + raise GuppyError(err) + # then we check for variable used inside the modifier block + if prev_use.kind == UseKind.DEFINED_IN_MODIFIER: + use = use_scope.used_parent[x] + err = ModifiedVariableUsedError(use.node, place) err.add_sub_diagnostic( - BorrowSubPlaceUsedError.PrevUse( - prev_use.node, prev_use.kind - ) + ModifiedVariableUsedError.ModifiedHere(prev_use.node, place) + ) + err.add_sub_diagnostic( + ModifiedVariableUsedError.Explanation(None) ) - err.add_sub_diagnostic(BorrowSubPlaceUsedError.Fix(None)) raise GuppyError(err) - err = AlreadyUsedError(use.node, place, use.kind) - err.add_sub_diagnostic( - AlreadyUsedError.PrevUse(prev_use.node, prev_use.kind) - ) - if has_explicit_copy(place.ty): - err.add_sub_diagnostic(AlreadyUsedError.MakeCopy(None)) - raise GuppyError(err) # On the other hand, unused variables that are not droppable *must* be outputted for place in scope.values(): diff --git a/guppylang-internals/src/guppylang_internals/checker/modifier_checker.py b/guppylang-internals/src/guppylang_internals/checker/modifier_checker.py index 26439bd72..367ad7ed4 100644 --- a/guppylang-internals/src/guppylang_internals/checker/modifier_checker.py +++ b/guppylang-internals/src/guppylang_internals/checker/modifier_checker.py @@ -1,8 +1,9 @@ """Type checking code for modifiers.""" import ast +from collections.abc import Collection -from guppylang_internals.ast_util import loop_in_ast, with_loc +from guppylang_internals.ast_util import AstNode, loop_in_ast, with_loc from guppylang_internals.cfg.bb import BB from guppylang_internals.checker.cfg_checker import check_cfg from guppylang_internals.checker.core import Context, Variable @@ -36,6 +37,7 @@ def check_modified_block( for x, using_bb in cfg.live_before[cfg.entry_bb].items() if x in ctx.locals } + modified_captured = _modified_captured_vars(modified_block, captured.keys()) # We do not allow any assignments if it is daggered. if modified_block.has_dagger(): @@ -83,12 +85,37 @@ def check_modified_block( checked_cfg, func_ty, captured, + modified_captured, modified_block.modifiers, **dict(ast.iter_fields(modified_block)), ) return with_loc(modified_block, checked_modifier) +def _modified_captured_vars( + modified_block: ModifiedBlock, captured_names: Collection[str] +) -> dict[str, AstNode]: + """Find captured variables assigned anywhere in a modifier body.""" + modified = {} + for body_bb in modified_block.cfg.bbs: + modified.update( + { + x: assignment + for x, assignment in body_bb.vars.assigned.items() + if x in captured_names + } + ) + for stmt in body_bb.statements: + assert not isinstance(stmt, CheckedModifiedBlock), ( + "CheckedModifiedBlocks should not be present while checking the cfg" + ) + if isinstance(stmt, ModifiedBlock): + # Recursively check nested modified blocks + modified.update(_modified_captured_vars(stmt, captured_names)) + + return modified + + def _set_inout_if_non_copyable(var: Variable) -> Variable: """Set the `inout` flag if the variable is non-copyable.""" if not var.ty.copyable: diff --git a/guppylang-internals/src/guppylang_internals/checker/stmt_checker.py b/guppylang-internals/src/guppylang_internals/checker/stmt_checker.py index 498faefb0..2f9922614 100644 --- a/guppylang-internals/src/guppylang_internals/checker/stmt_checker.py +++ b/guppylang-internals/src/guppylang_internals/checker/stmt_checker.py @@ -420,10 +420,10 @@ def visit_ModifiedBlock(self, node: ModifiedBlock) -> ast.stmt: raise InternalGuppyError("BB required to check with block!") # check the body of the modified block - modified_block = check_modified_block(node, self.bb, self.ctx) + checked_modified_block = check_modified_block(node, self.bb, self.ctx) # check the arguments of the control and power. - for control in modified_block.control: + for control in checked_modified_block.control: ctrl = control.ctrl # This case is handled during CFG construction. assert len(ctrl) > 0 @@ -447,13 +447,13 @@ def visit_ModifiedBlock(self, node: ModifiedBlock) -> ast.stmt: assert len(subst) == 0 control.qubit_num = len(ctrl) - for power in node.power: + for power in checked_modified_block.power: power.iter, subst = self._check_expr( power.iter, NumericType(NumericType.Kind.Nat) ) assert len(subst) == 0 - return modified_block + return checked_modified_block def visit_If(self, node: ast.If) -> None: raise InternalGuppyError("Control-flow statement should not be present here.") diff --git a/guppylang-internals/src/guppylang_internals/nodes.py b/guppylang-internals/src/guppylang_internals/nodes.py index 605aee020..9ee604865 100644 --- a/guppylang-internals/src/guppylang_internals/nodes.py +++ b/guppylang-internals/src/guppylang_internals/nodes.py @@ -887,6 +887,8 @@ class CheckedModifiedBlock(ast.With): ty: FunctionType #: Mapping from names to variables captured in the body. captured: Mapping[str, tuple["Variable", AstNode]] + #: Mapping from captured names to assignments inside the body. + modified_captured: Mapping[str, AstNode] def __init__( self, @@ -894,6 +896,7 @@ def __init__( cfg: "CheckedCFG[Place]", ty: FunctionType, captured: Mapping[str, tuple["Variable", AstNode]], + modified_captured: Mapping[str, AstNode], modifiers: Modifiers, *args: Any, **kwargs: Any, @@ -903,6 +906,7 @@ def __init__( self.cfg = cfg self.ty = ty self.captured = captured + self.modified_captured = modified_captured self.modifiers = modifiers @property diff --git a/tests/error/errors_on_usage/branch_in_modifier.err b/tests/error/errors_on_usage/branch_in_modifier.err new file mode 100644 index 000000000..d887d594c --- /dev/null +++ b/tests/error/errors_on_usage/branch_in_modifier.err @@ -0,0 +1,8 @@ +Error: Variable not defined (at $FILE:10:8) + | + 8 | x = 3 + 9 | a = b +10 | c = x +1 + | ^ `x` is not defined + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/errors_on_usage/branch_in_modifier.py b/tests/error/errors_on_usage/branch_in_modifier.py new file mode 100644 index 000000000..064210299 --- /dev/null +++ b/tests/error/errors_on_usage/branch_in_modifier.py @@ -0,0 +1,14 @@ +from guppylang.decorator import guppy +from guppylang.std.builtins import power + +@guppy +def test(b: bool) -> int: + with power(2): + if b: + x = 3 + a = b + c = x +1 + return 0 + + +test.check() diff --git a/tests/error/modifier_errors/captured_classical_modified.err b/tests/error/modifier_errors/captured_classical_modified.err new file mode 100644 index 000000000..7d13a5f2b --- /dev/null +++ b/tests/error/modifier_errors/captured_classical_modified.err @@ -0,0 +1,17 @@ +Error: Variable modified in modifier block (at $FILE:10:11) + | + 8 | with power(2): + 9 | x += 1 +10 | return x + | ^ Cannot use `x` because it was modified inside a modifier + | block + +Note: + | + 8 | with power(2): + 9 | x += 1 + | ------ `x` modified here + +Help: Modifications inside modifier blocks are not reflected outside the block + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/modifier_errors/captured_classical_modified.py b/tests/error/modifier_errors/captured_classical_modified.py new file mode 100644 index 000000000..de49b4eba --- /dev/null +++ b/tests/error/modifier_errors/captured_classical_modified.py @@ -0,0 +1,13 @@ +from guppylang.decorator import guppy +from guppylang.std.builtins import power + + +@guppy +def test() -> int: + x = 3 + with power(2): + x += 1 + return x + + +test.compile() diff --git a/tests/error/modifier_errors/captured_classical_modified_branch.err b/tests/error/modifier_errors/captured_classical_modified_branch.err new file mode 100644 index 000000000..170fd8542 --- /dev/null +++ b/tests/error/modifier_errors/captured_classical_modified_branch.err @@ -0,0 +1,17 @@ +Error: Variable modified in modifier block (at $FILE:11:15) + | + 9 | x += 1 +10 | if b: +11 | return x + | ^ Cannot use `x` because it was modified inside a modifier + | block + +Note: + | + 8 | with power(2): + 9 | x += 1 + | ------ `x` modified here + +Help: Modifications inside modifier blocks are not reflected outside the block + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/modifier_errors/captured_classical_modified_branch.py b/tests/error/modifier_errors/captured_classical_modified_branch.py new file mode 100644 index 000000000..34df1562d --- /dev/null +++ b/tests/error/modifier_errors/captured_classical_modified_branch.py @@ -0,0 +1,15 @@ +from guppylang.decorator import guppy +from guppylang.std.builtins import power + + +@guppy +def test(b: bool) -> int: + x = 3 + with power(2): + x += 1 + if b: + return x + return 0 + + +test.compile_function() diff --git a/tests/error/modifier_errors/captured_classical_modified_multiple.err b/tests/error/modifier_errors/captured_classical_modified_multiple.err new file mode 100644 index 000000000..a27b7ee63 --- /dev/null +++ b/tests/error/modifier_errors/captured_classical_modified_multiple.err @@ -0,0 +1,17 @@ +Error: Variable modified in modifier block (at $FILE:13:11) + | +11 | x += 1 +12 | y -= 2 +13 | return x + y + | ^ Cannot use `x` because it was modified inside a modifier + | block + +Note: + | +10 | with power(3): +11 | x += 1 + | ------ `x` modified here + +Help: Modifications inside modifier blocks are not reflected outside the block + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/modifier_errors/captured_classical_modified_multiple.py b/tests/error/modifier_errors/captured_classical_modified_multiple.py new file mode 100644 index 000000000..6b388e8b8 --- /dev/null +++ b/tests/error/modifier_errors/captured_classical_modified_multiple.py @@ -0,0 +1,16 @@ +from guppylang.decorator import guppy +from guppylang.std.builtins import power, control, qubit + + +@guppy +def test(q: qubit) -> int: + x = 1 + y = 10 + with control(q): # noqa: SIM117 + with power(3): + x += 1 + y -= 2 + return x + y + + +test.compile() diff --git a/tests/error/modifier_errors/captured_classical_modified_nested.err b/tests/error/modifier_errors/captured_classical_modified_nested.err new file mode 100644 index 000000000..9b7511e32 --- /dev/null +++ b/tests/error/modifier_errors/captured_classical_modified_nested.err @@ -0,0 +1,17 @@ +Error: Variable modified in modifier block (at $FILE:12:11) + | +10 | with control(q): +11 | x += 1 +12 | return x + | ^ Cannot use `x` because it was modified inside a modifier + | block + +Note: + | +10 | with control(q): +11 | x += 1 + | ------ `x` modified here + +Help: Modifications inside modifier blocks are not reflected outside the block + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/modifier_errors/captured_classical_modified_nested.py b/tests/error/modifier_errors/captured_classical_modified_nested.py new file mode 100644 index 000000000..68031289b --- /dev/null +++ b/tests/error/modifier_errors/captured_classical_modified_nested.py @@ -0,0 +1,15 @@ +from guppylang.decorator import guppy +from guppylang.std.builtins import power, control, qubit +from guppylang.std.num import nat + + +@guppy +def test(n: nat, q: qubit) -> int: + x = 0 + with power(n): # noqa: SIM117 + with control(q): + x += 1 + return x + + +test.compile_function() diff --git a/tests/error/modifier_errors/captured_classical_modified_sequential.err b/tests/error/modifier_errors/captured_classical_modified_sequential.err new file mode 100644 index 000000000..6e7138bfb --- /dev/null +++ b/tests/error/modifier_errors/captured_classical_modified_sequential.err @@ -0,0 +1,17 @@ +Error: Variable modified in modifier block (at $FILE:11:8) + | + 9 | x += 1 +10 | with power(3): +11 | x += 2 + | ^ Cannot use `x` because it was modified inside a modifier + | block + +Note: + | + 8 | with power(2): + 9 | x += 1 + | ------ `x` modified here + +Help: Modifications inside modifier blocks are not reflected outside the block + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/modifier_errors/captured_classical_modified_sequential.py b/tests/error/modifier_errors/captured_classical_modified_sequential.py new file mode 100644 index 000000000..6b1608486 --- /dev/null +++ b/tests/error/modifier_errors/captured_classical_modified_sequential.py @@ -0,0 +1,15 @@ +from guppylang.decorator import guppy +from guppylang.std.builtins import power + + +@guppy +def test() -> int: + x = 3 + with power(2): + x += 1 + with power(3): + x += 2 + return x + + +test.compile()