Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
03165ca
Add protocol definition parsing
tatiana-s Sep 22, 2025
52ce8eb
Merge remote-tracking branch 'origin/main' into parsing
tatiana-s Sep 22, 2025
6a2cad0
Format
tatiana-s Sep 24, 2025
1171ad2
Rename test file to stop 3.10 checks from failing
tatiana-s Sep 25, 2025
fa4d04c
Move `get_instance_func`
tatiana-s Sep 25, 2025
60ca51b
Merge remote-tracking branch 'origin/main' into parsing
tatiana-s Nov 10, 2025
0617428
Merge remote-tracking branch 'origin/main' into parsing
tatiana-s Nov 17, 2025
032d30b
Address some comments
tatiana-s Nov 17, 2025
2d3fc04
delete py312 error folder for now
tatiana-s Nov 17, 2025
9744471
Remove declaration annotation
tatiana-s Nov 24, 2025
d971403
Add bounds parsing
tatiana-s Dec 8, 2025
41b8f47
Merge branch 'protocol-main' into parsing
tatiana-s Dec 8, 2025
322febd
Merge branch 'protocol-main' into parsing
tatiana-s Dec 9, 2025
02c1e12
Merge remote-tracking branch 'origin/main' into ts/move-to-engine
tatiana-s Dec 9, 2025
12564f0
Merge branch 'ts/move-to-engine' into checking-fresh
tatiana-s Dec 9, 2025
dc6aa9c
Copy initial protocol checker and call it during checking
tatiana-s Dec 9, 2025
a427465
Merge remote-tracking branch 'origin/protocol-main' into parsing
tatiana-s Dec 17, 2025
0607ec2
Require explicit params + minor fixes
tatiana-s Dec 17, 2025
5f88618
Merge branch 'parsing' into checking-fresh
tatiana-s Dec 17, 2025
7291013
Continue checking
tatiana-s Dec 17, 2025
7ac69a0
Start protocol call checking, fix parameterised test syntax, fix some…
tatiana-s Jan 5, 2026
7784670
Add assumption test, tuples instead of lists for hashing, debug proto…
tatiana-s Jan 8, 2026
66e1508
More protocol call checking
tatiana-s Jan 12, 2026
05831bb
Merge remote-tracking branch 'origin/main' into checking-fresh
croyzor Jan 14, 2026
441ca2d
Continue trying to get checking to work
tatiana-s Jan 16, 2026
85d1b0c
Continue protocol checking
tatiana-s Jan 19, 2026
ce63fbc
Add more examples
tatiana-s Jan 21, 2026
9ab4e5c
Fix mypy errors
croyzor Jan 27, 2026
6fa94d4
Delete redundant post_init restriction
croyzor Jan 27, 2026
920f5bb
Add missing transformation functions
croyzor Jan 27, 2026
c326940
fixups
croyzor Mar 17, 2026
9807b0f
Get protocol example working
croyzor Mar 18, 2026
8faba54
Drop bogus assertion
croyzor Mar 18, 2026
aba8099
fix: Return any from ExistentialVar.transform
croyzor Mar 18, 2026
7361cd6
Fixups
croyzor Mar 18, 2026
c312c86
Merge branch 'protocol-main' into checking-fresh
croyzor Mar 20, 2026
f43c8fa
Fixups
croyzor Mar 20, 2026
1ed193c
Set protocol tests to check, not compile
croyzor Mar 20, 2026
a2dfe38
Make must-implement-protocol fields tuples so they can be hashed
croyzor Mar 23, 2026
feb97c0
refactor: Drive-by variable name change
croyzor Mar 24, 2026
72819f5
fix: Unify the types of const args
croyzor Mar 24, 2026
1111460
Make unquantified unquantify more
croyzor Mar 25, 2026
00784cb
Add transform method to ExistentialVar
croyzor Mar 25, 2026
4d5da77
Transform the type of bound consts
croyzor Mar 25, 2026
f421f45
Add `compare=False` for ExistentialConstVar types
croyzor Apr 7, 2026
ebce763
Stuff that seems sensible but not so sure what it does
croyzor Mar 25, 2026
90196f8
Steal from cr/better-inst PR
croyzor Apr 8, 2026
9989a03
Update snapshots
croyzor Apr 8, 2026
7f5e632
Merge branch 'protocol-main' into checking-fresh
croyzor Apr 9, 2026
9674b68
Merge remote-tracking branch 'origin/cr/better-inst' into checking-fresh
croyzor Apr 9, 2026
77c28a4
Merge branch 'protocol-main' into checking-fresh
croyzor Apr 22, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import copy
import sys
import traceback
from collections.abc import Sequence
from collections.abc import Mapping, Sequence
from contextlib import suppress
from dataclasses import replace
from types import ModuleType
Expand Down Expand Up @@ -119,13 +119,14 @@
MakeIter,
PartialApply,
PlaceNode,
ProtocolCall,
SubscriptAccessAndDrop,
TensorCall,
TupleAccessAndDrop,
TypeApply,
)
from guppylang_internals.span import Span, to_span
from guppylang_internals.tys.arg import ConstArg, TypeArg
from guppylang_internals.tys.arg import Argument, ConstArg, TypeArg
from guppylang_internals.tys.builtin import (
bool_type,
float_type,
Expand All @@ -147,10 +148,16 @@
ConstValue,
ExistentialConstVar,
)
from guppylang_internals.tys.param import TypeParam, check_all_args
from guppylang_internals.tys.param import (
ConstParam,
Parameter,
TypeParam,
check_all_args,
)
from guppylang_internals.tys.parsing import arg_from_ast
from guppylang_internals.tys.subst import Inst, Subst
from guppylang_internals.tys.ty import (
BoundTypeVar,
EnumType,
ExistentialTypeVar,
FuncInput,
Expand Down Expand Up @@ -323,6 +330,23 @@ def visit_Call(self, node: ast.Call, ty: Type) -> tuple[ast.expr, Subst]:
if isinstance(defn, CallableDef):
return defn.check_call(node.args, ty, node, self.ctx)

from guppylang_internals.definition.protocol import ParsedProtocolDef

# Protocol methods don't have their own definition, we have to look up the
# protocol definition itself first.
if isinstance(defn, ParsedProtocolDef):
assert isinstance(func_ty, FunctionType)
args, subst, inst = check_call(func_ty, node.args, ty, node, self.ctx)
return with_loc(
node,
ProtocolCall(
member=node.func.id,
proto_id=node.func.def_id,
args=args,
type_args=inst,
),
), subst

# When calling a `PartialApply` node, we just move the args into this call
if isinstance(node.func, PartialApply):
node.args = [*node.func.args, *node.args]
Expand Down Expand Up @@ -605,6 +629,28 @@ def visit_Attribute(self, node: ast.Attribute) -> tuple[ast.expr, Type]:
# you loose access to all fields besides `a`).
expr = FieldAccessAndDrop(value=node.value, struct_ty=ty, field=field)
return with_loc(node, expr), field.ty
elif isinstance(ty, BoundTypeVar):
from guppylang_internals.definition.protocol import CheckedProtocolDef

for proto in ty.implements:
proto_def = ENGINE.get_checked(proto.def_id, proto.type_args)
assert isinstance(proto_def, CheckedProtocolDef)
for member_name, member_ty in proto_def.members.items():
if node.attr == member_name:
name_node = with_type(
member_ty,
with_loc(
node,
# TODO: Should we have a different AST node for this?
GlobalName(id=member_name, def_id=proto.def_id),
),
)
ty_without_self = FunctionType(
member_ty.inputs[1:], member_ty.output, member_ty.params
)
return with_loc(
node, PartialApply(func=name_node, args=[node.value])
), ty_without_self
elif isinstance(ty, EnumType):
if node.attr in ty.variants_as_dict:
# If we are accessing to a variant, we need to check that node.value is
Expand Down Expand Up @@ -870,8 +916,26 @@ def visit_Call(self, node: ast.Call) -> tuple[ast.expr, Type]:
if isinstance(node.func, GlobalName):
defn = self.ctx.globals[node.func.def_id]
if isinstance(defn, CallableDef):
# TODO: Should we error here if not callable?
return defn.synthesize_call(node.args, node, self.ctx)

from guppylang_internals.definition.protocol import ParsedProtocolDef

# Protocol methods don't have their own definition, we have to look up the
# protocol definition itself first.
if isinstance(defn, ParsedProtocolDef):
assert isinstance(ty, FunctionType)
args, return_ty, inst = synthesize_call(ty, node.args, node, self.ctx)
return with_loc(
node,
ProtocolCall(
member=node.func.id,
proto_id=node.func.def_id,
args=args,
type_args=inst,
),
), return_ty

# When calling a `PartialApply` node, we just move the args into this call
if isinstance(node.func, PartialApply):
node.args = [*node.func.args, *node.args]
Expand Down Expand Up @@ -1112,6 +1176,7 @@ def type_check_args(
inputs: list[ast.expr],
func_ty: FunctionType,
subst: Subst,
free_var_mapping: Mapping[ExistentialVar, Parameter],
ctx: Context,
node: AstNode,
) -> tuple[list[ast.expr], Subst]:
Expand All @@ -1127,6 +1192,21 @@ def type_check_args(
comptime_args = iter(func_ty.comptime_args)
for inp, func_inp in zip(inputs, func_ty.inputs, strict=True):
a, s = ExprChecker(ctx).check(inp, func_inp.ty.substitute(subst), "argument")
# For each new substitution we find for any previously uninstantiated parameter,
# we check it in order to possibly infer more substitutions through protocol
# checking.
for var in s:
if var in free_var_mapping:
param = free_var_mapping[var]
match param, s[var].to_arg():
case TypeParam(), TypeArg() as arg:
check_arg, check_subst = param.check_arg(arg, a)
subst |= check_subst
subst[var] = check_arg.ty
case ConstParam(), ConstArg() as arg:
subst[var] = param.check_arg(arg, a).const
case _:
raise Exception("Bad kinding")
subst |= s
if InputFlags.Inout in func_inp.flags and isinstance(a, PlaceNode):
a.place = check_place_assignable(
Expand Down Expand Up @@ -1264,7 +1344,16 @@ def synthesize_call(
# Replace quantified variables with free unification variables and try to infer an
# instantiation by checking the arguments
unquantified, free_vars = func_ty.unquantified()
args, subst = type_check_args(args, unquantified, {}, ctx, node)
var_mapping = {}
inst: list[Argument | None] = [None for _ in free_vars]
for ix, (var, param) in enumerate(zip(free_vars, func_ty.params, strict=True)):
var_mapping[var] = param.instantiate_bounds(inst)
if isinstance(var, ExistentialTypeVar):
inst[ix] = TypeArg(var)
elif isinstance(var, ExistentialConstVar):
inst[ix] = ConstArg(var)

args, subst = type_check_args(args, unquantified, {}, var_mapping, ctx, node)

# Success implies that the substitution is closed
assert all(not t.unsolved_vars for t in subst.values())
Expand Down Expand Up @@ -1339,7 +1428,8 @@ def check_call(
raise GuppyTypeError(TypeMismatchError(node, ty, unquantified.output, kind))

# Try to infer more by checking against the arguments
inputs, subst = type_check_args(inputs, unquantified, subst, ctx, node)
# CR TODO: Make the var mapping here as in synthesize
inputs, subst = type_check_args(inputs, unquantified, subst, {}, ctx, node)

# Also make sure we found an instantiation for all free vars in the type we're
# checking against
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
UnnamedTupleNotUsedError,
)
from guppylang_internals.definition.custom import CustomFunctionDef
from guppylang_internals.definition.protocol import CheckedProtocolDef
from guppylang_internals.definition.value import CallableDef
from guppylang_internals.engine import DEF_STORE, ENGINE
from guppylang_internals.error import GuppyError, GuppyTypeError
Expand All @@ -64,6 +65,7 @@
LocalCall,
PartialApply,
PlaceNode,
ProtocolCall,
StateResultExpr,
SubscriptAccessAndDrop,
TensorCall,
Expand Down Expand Up @@ -448,6 +450,13 @@ def visit_LocalCall(self, node: LocalCall) -> None:
self._visit_call_args(func_ty, node)
self._reassign_inout_args(func_ty, node)

def visit_ProtocolCall(self, node: ProtocolCall) -> None:
proto = ENGINE.get_checked(node.proto_id, node.type_args)
assert isinstance(proto, CheckedProtocolDef)
func_ty = proto.members[node.member].instantiate(node.type_args)
self._visit_call_args(func_ty, node)
self._reassign_inout_args(func_ty, node)

def visit_TensorCall(self, node: TensorCall) -> None:
for arg in node.args:
self.visit(arg)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
from abc import ABC
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from typing import ClassVar, TypeAlias

from guppylang_internals.definition.common import DefId
from guppylang_internals.definition.protocol import CheckedProtocolDef
from guppylang_internals.diagnostic import Error
from guppylang_internals.engine import ENGINE
from guppylang_internals.error import GuppyError
from guppylang_internals.tys.arg import Argument, ConstArg, TypeArg
from guppylang_internals.tys.const import BoundConstVar, ExistentialConstVar
from guppylang_internals.tys.protocol import ProtocolInst
from guppylang_internals.tys.subst import Inst, Subst, Substituter
from guppylang_internals.tys.ty import (
BoundTypeVar,
ExistentialTypeVar,
FunctionType,
Type,
unify,
unify_type_args,
)
from guppylang_internals.tys.var import ExistentialVar


@dataclass(frozen=True)
class ImplProofBase(ABC):
proto: ProtocolInst
ty: Type

def __post_init__(self) -> None:
assert all(not arg.unsolved_vars for arg in self.proto.type_args)


@dataclass(frozen=True)
class ConcreteImplProof(ImplProofBase):
#: For each protocol member, the concrete function that implements it together with
#: an instantiation of the type parameters of the implementation. This could refer
#: to bound variables specified by the protocol method.
#: If we have a protocol `def foo[T](x: T, y: int)` and implementation
#: `def foo[A, B](x: A, y: B)`, then the instantiation will specify `A := T` and
#: `B := int`.
member_impls: Mapping[str, tuple[DefId, Inst]]

# def __post_init__(self) -> None:
# assert self.member_impls.keys() == self.proto.members.keys()


@dataclass(frozen=True)
class AssumptionImplProof(ImplProofBase):
ty: BoundTypeVar

def __post_init__(self) -> None:
super().__post_init__()
# assert self.proto in self.ty.implements


ImplProof: TypeAlias = ConcreteImplProof | AssumptionImplProof


@dataclass(frozen=True)
class ProtocolMemberMissing(Error):
title: ClassVar[str] = "Protocol member implementation missing"
span_label: ClassVar[str] = (
"Type {impl_name} does not implement member {member_name} of "
"protocol {proto_name}"
)
impl_name: str
proto_name: str
member_name: str


def _unify_args(
xs: Sequence[ExistentialVar], ys: Sequence[Argument], subst: Subst | None
) -> Subst | None:
for x, y in zip(xs, ys, strict=True):
## CR TODO: What are we doing about potential must_implement protocols in `ty`?
match x, y:
case ExistentialTypeVar(), TypeArg(ty=ty):
subst = unify(x, ty, subst)
case ExistentialConstVar(), ConstArg(const=const):
subst = unify(x, const, subst)
case _:
# CR: Is this always an InternalGuppyError?
raise Exception("Const vs Type arg for protocol")
return subst


def _instantiate_self(
proto_func: FunctionType, proto_inst: ProtocolInst, impl_ty: Type
) -> FunctionType:
# Assumption: first argument must be self.
self_ty = proto_func.inputs[0].ty
assert isinstance(self_ty, BoundTypeVar)
[bound] = self_ty.implements
assert bound.def_id == proto_inst.def_id
# A mutable PartialInst
partial_inst: list[Argument | None] = [None for _ in proto_func.params]
# Instantiate all self type occurrences in protocol methods with the type we assume
# is implementing the protocol.
for proto_arg, bound_arg in zip(proto_inst.type_args, bound.type_args, strict=True):
match bound_arg:
case TypeArg(ty=BoundTypeVar(idx=idx)):
partial_inst[idx] = proto_arg
case ConstArg(const=BoundConstVar(idx=idx)):
partial_inst[idx] = proto_arg
partial_inst[self_ty.idx] = impl_ty.to_arg()
return proto_func.instantiate_partial(partial_inst)


def check_protocol(ty: Type, protocol: ProtocolInst) -> tuple[ImplProof, Subst]:
"""Check that `ty` implements `protocol`"""

# Invariant: `ty` and `protocol` might have unsolved variables.
protocol_def = ENGINE.get_checked(protocol.def_id, protocol.type_args)
assert isinstance(protocol_def, CheckedProtocolDef)

# If `ty` is a bound type variable, we try to handle the case
# `def foo[T, MyProto: Proto[T]](MyProto, ...) -> ...`
# ... we must assume that bound variable `MyProto` implements `Proto[T]`
# when `check_protocol` is invoked for this definition.
if isinstance(ty, BoundTypeVar):
# Iterate over all of the "must implement" bounds for `ty`, and collect
# the ones that result in an implementation of `protocol`.
# We hope there's only one answer!
candidates: list[tuple[ProtocolInst, Subst]] = []
for impl in ty.implements:
if impl.def_id == protocol.def_id:
# TODO: Is this correct?
# Does it break if we have protocols with other proto args
subst = unify_type_args(protocol.type_args, impl.type_args, {})
if subst is not None:
candidates.append((impl, subst))
if len(candidates) == 0:
raise Exception("Zero")
elif len(candidates) > 1:
raise Exception("more than one")
[(_, subst)] = candidates
new_ty = ty.substitute(subst)
assert isinstance(new_ty, BoundTypeVar)
return AssumptionImplProof(
protocol.transform(Substituter(subst)),
# CR: Should we return `new_ty` (substituted) or `ty`?
new_ty,
), subst

subst = {}
member_impls: dict[str, tuple[DefId, Inst]] = {}
for name, proto_sig in protocol_def.members.items():
assert isinstance(proto_sig, FunctionType)
# Partially instantiate proto_sig with `protocol.type_args` and `ty`.
proto_sig = _instantiate_self(proto_sig, protocol, ty)
func = ENGINE.get_instance_func(ty, name)
if not func:
raise GuppyError(
ProtocolMemberMissing(
ENGINE.get_parsed(
protocol.def_id
).defined_at, # CR: Dummy ast location, we can do better
impl_name=str(ty),
proto_name=protocol_def.name,
member_name=name,
)
)
# Make type variables in implementation signature existential for unification.
impl_sig, ex_impl_vars = func.ty.unquantified()
# Make parameters in protocol signature unbound for unification.
proto_sig = FunctionType(proto_sig.inputs, proto_sig.output, params=[])
# Try to unify both signatures.
subst = unify(proto_sig, impl_sig, subst)
if subst is None:
raise Exception("Signature Mismatch")
if any(x not in subst for x in ex_impl_vars):
raise Exception("Unresolved variables in implementation")
# Turn these into type vars
impl_vars: Inst = tuple(subst[var].to_arg() for var in ex_impl_vars)
member_impls[name] = func.id, impl_vars

if any(x not in subst for arg in protocol.type_args for x in arg.unsolved_vars):
raise Exception("Couldn't figure out variables in protocol")
subst = {x: subst[x] for arg in protocol.type_args for x in arg.unsolved_vars}
return ConcreteImplProof(
protocol.transform(Substituter(subst)), ty, member_impls
), subst
Loading
Loading