Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,5 @@ devenv.local.nix
**/.benchmarks/
.vscode/settings.json
guppy-exports/

AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def check_cfg(
`first_modifier_node`: if None, the cfg is not a modifier block.
Otherwise, it's the AST node of the first modifier, used in error reporting.
"""

# First, we need to run program analysis
ass_before = {v.name for v in inputs}
inout_vars = [v for v in inputs if InputFlags.Inout in v.flags]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,26 @@ class MakeCopy(Help):
)


@dataclass(frozen=True)
class ModifiedVariableUsedError(Error):
title: ClassVar[str] = "Variable modified in modifier block"
span_label: ClassVar[str] = (
"Cannot use `{place}` because it was modified inside a modifier block"
)
place: Place

@dataclass(frozen=True)
class ModifiedHere(Note):
span_label: ClassVar[str] = "`{place}` modified here"
place: Place

@dataclass(frozen=True)
class Explanation(Help):
message: ClassVar[str] = (
"Modifications inside modifier blocks are not reflected outside the block"
)


@dataclass(frozen=True)
class ComprAlreadyUsedError(Error):
title: ClassVar[str] = "Copy violation"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
ComprAlreadyUsedError,
DropAfterCallError,
InCallArg,
ModifiedVariableUsedError,
MoveOutOfSubscriptError,
NonCopyableCaptureError,
NonCopyablePartialApplyError,
Expand Down Expand Up @@ -101,6 +102,9 @@ class UseKind(Enum):
#: An owned value is renamed or stored in a tuple/list
MOVE = auto()

#: A captured value is assigned inside a modifier block
DEFINED_IN_MODIFIER = auto()

@property
def indicative(self) -> str:
"""Describes a use in an indicative mood.
Expand All @@ -126,6 +130,8 @@ def subjunctive(self) -> str:
return "returned"
case UseKind.MOVE:
return "moved"
case UseKind.DEFINED_IN_MODIFIER:
return "modified"


class Use(NamedTuple):
Expand Down Expand Up @@ -298,9 +304,11 @@ def visit_PlaceNode(
self.visit(subscript.getitem_call)
# For all other places, we record uses of all leaves
else:
# Check each leaf separately so we catch partial moves, e.g. struct fields
for place in leaf_places(node.place):
x = place.id
if (prev_use := self.scope.used(x)) and not place.ty.copyable:
prev_use = self.scope.used(x)
if prev_use and not place.ty.copyable:
# When the user's expression (node.place) differs from the
# conflicting leaf (place), report the error about the parent
# and explain which child was already moved.
Expand All @@ -319,6 +327,15 @@ def visit_PlaceNode(
if has_explicit_copy(place.ty):
err.add_sub_diagnostic(AlreadyUsedError.MakeCopy(None))
raise GuppyError(err)
# A modifier-block assignment makes the original binding stale, so
# any subsequent use in the same scope is rejected.
if prev_use and prev_use.kind == UseKind.DEFINED_IN_MODIFIER:
err = ModifiedVariableUsedError(node, place)
err.add_sub_diagnostic(
ModifiedVariableUsedError.ModifiedHere(prev_use.node, place)
)
err.add_sub_diagnostic(ModifiedVariableUsedError.Explanation(None))
raise GuppyError(err)
self.scope.use(x, node, use_kind)

def visit_Assign(self, node: ast.Assign) -> None:
Expand Down Expand Up @@ -685,7 +702,7 @@ def visit_CheckedModifiedBlock(self, node: CheckedModifiedBlock) -> None:
# body(q1, q2, ...)
# ```

# check control
# Check control
for ctrl in node.control:
for arg in ctrl.ctrl:
if isinstance(arg, PlaceNode):
Expand All @@ -696,7 +713,7 @@ def visit_CheckedModifiedBlock(self, node: CheckedModifiedBlock) -> None:
unnamed_err.add_sub_diagnostic(UnnamedExprNotUsedError.Fix(None))
raise GuppyTypeError(unnamed_err)

# check power
# Check power
for power in node.power:
if isinstance(power.iter, PlaceNode):
self.visit_PlaceNode(
Expand All @@ -705,22 +722,34 @@ def visit_CheckedModifiedBlock(self, node: CheckedModifiedBlock) -> None:
else:
self.visit(power.iter)

# check captured variables
# Check captured variables:
# We check that the modifier is not using consumed variables or copyable
# variables already used by other modifiers
for var, use in node.captured.values():
for place in leaf_places(var):
use_kind = (
UseKind.BORROW if InputFlags.Inout in var.flags else UseKind.CONSUME
)

x = place.id
if (prev_use := self.scope.used(x)) and not place.ty.copyable:
used_err = AlreadyUsedError(use, place, use_kind)
used_err.add_sub_diagnostic(
AlreadyUsedError.PrevUse(prev_use.node, prev_use.kind)
)
if has_explicit_copy(place.ty):
used_err.add_sub_diagnostic(AlreadyUsedError.MakeCopy(None))
raise GuppyError(used_err)
if prev_use := self.scope.used(x):
if not place.ty.copyable:
used_err = AlreadyUsedError(use, place, use_kind)
used_err.add_sub_diagnostic(
AlreadyUsedError.PrevUse(prev_use.node, prev_use.kind)
)
if has_explicit_copy(place.ty):
used_err.add_sub_diagnostic(AlreadyUsedError.MakeCopy(None))
raise GuppyError(used_err)
if prev_use.kind == UseKind.DEFINED_IN_MODIFIER:
err = ModifiedVariableUsedError(use, place)
err.add_sub_diagnostic(
ModifiedVariableUsedError.ModifiedHere(prev_use.node, place)
)
err.add_sub_diagnostic(
ModifiedVariableUsedError.Explanation(None)
)
raise GuppyError(err)

self.scope.use(x, node, use_kind)

# reassign controls
Expand All @@ -734,6 +763,12 @@ def visit_CheckedModifiedBlock(self, node: CheckedModifiedBlock) -> None:
if InputFlags.Inout in var.flags:
self._reassign_single_inout_arg(var, var.defined_at or use)

for name, assignment in node.modified_captured.items():
var, _ = node.captured[name]
if var.ty.copyable:
for place in leaf_places(var):
self.scope.use(place.id, assignment, UseKind.DEFINED_IN_MODIFIER)


def leaf_places(place: Place) -> Iterator[Place]:
"""Returns all leaf descendant projections of a place."""
Expand Down Expand Up @@ -889,36 +924,51 @@ def check_cfg_linearity(
for bb, scope in scopes.items():
live_before_bb = live_before[bb]

# We have to check that used not copyable variables are not being outputted
# Check that values made unusable in this block are not live in successors.
# This catches both non-copyable values that were already used and copyable
# captures that were re-assigned inside a modifier block.
for succ in bb.successors:
live = live_before[succ]
for x, use_bb in live.items():
use_scope = scopes[use_bb]
place = use_scope[x]
if not place.ty.copyable and (prev_use := scope.used(x)):
use = use_scope.used_parent[x]
# Special case if this is a use arising from the implicit returning
# of a borrowed argument
if isinstance(use.node, InoutReturnSentinel):
assert isinstance(use.node.var, Variable)
assert InputFlags.Inout in use.node.var.flags
err: Error = BorrowSubPlaceUsedError(
use.node.var.defined_at, use.node.var, place
if prev_use := scope.used(x):
# first we check for variable non-copyable variable
if not place.ty.copyable:
use = use_scope.used_parent[x]
# Special case if this is a use arising from the implicit
# returning of a borrowed argument
if isinstance(use.node, InoutReturnSentinel):
assert isinstance(use.node.var, Variable)
assert InputFlags.Inout in use.node.var.flags
err: Error = BorrowSubPlaceUsedError(
use.node.var.defined_at, use.node.var, place
)
err.add_sub_diagnostic(
BorrowSubPlaceUsedError.PrevUse(
prev_use.node, prev_use.kind
)
)
err.add_sub_diagnostic(BorrowSubPlaceUsedError.Fix(None))
raise GuppyError(err)
err = AlreadyUsedError(use.node, place, use.kind)
err.add_sub_diagnostic(
AlreadyUsedError.PrevUse(prev_use.node, prev_use.kind)
)
if has_explicit_copy(place.ty):
err.add_sub_diagnostic(AlreadyUsedError.MakeCopy(None))
raise GuppyError(err)
# then we check for variable used inside the modifier block
if prev_use.kind == UseKind.DEFINED_IN_MODIFIER:
use = use_scope.used_parent[x]
err = ModifiedVariableUsedError(use.node, place)
err.add_sub_diagnostic(
BorrowSubPlaceUsedError.PrevUse(
prev_use.node, prev_use.kind
)
ModifiedVariableUsedError.ModifiedHere(prev_use.node, place)
)
err.add_sub_diagnostic(
ModifiedVariableUsedError.Explanation(None)
)
err.add_sub_diagnostic(BorrowSubPlaceUsedError.Fix(None))
raise GuppyError(err)
err = AlreadyUsedError(use.node, place, use.kind)
err.add_sub_diagnostic(
AlreadyUsedError.PrevUse(prev_use.node, prev_use.kind)
)
if has_explicit_copy(place.ty):
err.add_sub_diagnostic(AlreadyUsedError.MakeCopy(None))
raise GuppyError(err)

# On the other hand, unused variables that are not droppable *must* be outputted
for place in scope.values():
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""Type checking code for modifiers."""

import ast
from collections.abc import Collection

from guppylang_internals.ast_util import loop_in_ast, with_loc
from guppylang_internals.ast_util import AstNode, loop_in_ast, with_loc
from guppylang_internals.cfg.bb import BB
from guppylang_internals.checker.cfg_checker import check_cfg
from guppylang_internals.checker.core import Context, Variable
Expand Down Expand Up @@ -36,6 +37,7 @@ def check_modified_block(
for x, using_bb in cfg.live_before[cfg.entry_bb].items()
if x in ctx.locals
}
modified_captured = _modified_captured_vars(modified_block, captured.keys())

# We do not allow any assignments if it is daggered.
if modified_block.has_dagger():
Expand Down Expand Up @@ -83,12 +85,37 @@ def check_modified_block(
checked_cfg,
func_ty,
captured,
modified_captured,
modified_block.modifiers,
**dict(ast.iter_fields(modified_block)),
)
return with_loc(modified_block, checked_modifier)


def _modified_captured_vars(
modified_block: ModifiedBlock, captured_names: Collection[str]
) -> dict[str, AstNode]:
"""Find captured variables assigned anywhere in a modifier body."""
modified = {}
for body_bb in modified_block.cfg.bbs:
modified.update(
{
x: assignment
for x, assignment in body_bb.vars.assigned.items()
if x in captured_names
}
)
for stmt in body_bb.statements:
assert not isinstance(stmt, CheckedModifiedBlock), (
"CheckedModifiedBlocks should not be present while checking the cfg"
)
if isinstance(stmt, ModifiedBlock):
# Recursively check nested modified blocks
modified.update(_modified_captured_vars(stmt, captured_names))

return modified


def _set_inout_if_non_copyable(var: Variable) -> Variable:
"""Set the `inout` flag if the variable is non-copyable."""
if not var.ty.copyable:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -420,10 +420,10 @@ def visit_ModifiedBlock(self, node: ModifiedBlock) -> ast.stmt:
raise InternalGuppyError("BB required to check with block!")

# check the body of the modified block
modified_block = check_modified_block(node, self.bb, self.ctx)
checked_modified_block = check_modified_block(node, self.bb, self.ctx)

# check the arguments of the control and power.
for control in modified_block.control:
for control in checked_modified_block.control:
ctrl = control.ctrl
# This case is handled during CFG construction.
assert len(ctrl) > 0
Expand All @@ -447,13 +447,13 @@ def visit_ModifiedBlock(self, node: ModifiedBlock) -> ast.stmt:
assert len(subst) == 0
control.qubit_num = len(ctrl)

for power in node.power:
for power in checked_modified_block.power:
power.iter, subst = self._check_expr(
power.iter, NumericType(NumericType.Kind.Nat)
)
assert len(subst) == 0

return modified_block
return checked_modified_block

def visit_If(self, node: ast.If) -> None:
raise InternalGuppyError("Control-flow statement should not be present here.")
Expand Down
4 changes: 4 additions & 0 deletions guppylang-internals/src/guppylang_internals/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -887,13 +887,16 @@ class CheckedModifiedBlock(ast.With):
ty: FunctionType
#: Mapping from names to variables captured in the body.
captured: Mapping[str, tuple["Variable", AstNode]]
#: Mapping from captured names to assignments inside the body.
modified_captured: Mapping[str, AstNode]

def __init__(
self,
def_id: "DefId",
cfg: "CheckedCFG[Place]",
ty: FunctionType,
captured: Mapping[str, tuple["Variable", AstNode]],
modified_captured: Mapping[str, AstNode],
modifiers: Modifiers,
*args: Any,
**kwargs: Any,
Expand All @@ -903,6 +906,7 @@ def __init__(
self.cfg = cfg
self.ty = ty
self.captured = captured
self.modified_captured = modified_captured
self.modifiers = modifiers

@property
Expand Down
8 changes: 8 additions & 0 deletions tests/error/errors_on_usage/branch_in_modifier.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Error: Variable not defined (at $FILE:10:8)
|
8 | x = 3
9 | a = b
10 | c = x +1
| ^ `x` is not defined

Guppy compilation failed due to 1 previous error
14 changes: 14 additions & 0 deletions tests/error/errors_on_usage/branch_in_modifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from guppylang.decorator import guppy
from guppylang.std.builtins import power

@guppy
def test(b: bool) -> int:
with power(2):
if b:
x = 3
a = b
c = x +1
return 0


test.check()
17 changes: 17 additions & 0 deletions tests/error/modifier_errors/captured_classical_modified.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
Error: Variable modified in modifier block (at $FILE:10:11)
|
8 | with power(2):
9 | x += 1
10 | return x
| ^ Cannot use `x` because it was modified inside a modifier
| block

Note:
|
8 | with power(2):
9 | x += 1
| ------ `x` modified here

Help: Modifications inside modifier blocks are not reflected outside the block

Guppy compilation failed due to 1 previous error
Loading
Loading