Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 48 additions & 3 deletions guppylang-internals/src/guppylang_internals/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,12 +395,34 @@ def check(self, def_ids: list[DefId], *, reset: bool = True) -> None:
if reset:
self.reset()

# We allow generic functions as checking entrypoints as long as we don't run
# into a check that requires monomorphization. For this, we check a version
# where all parameters are instantiated to opaque `BoundVariable`s.
for def_id in def_ids:
entry_defn = self.get_parsed(def_id)
check_entry_point_non_generic(entry_defn)
entry_mono_args: Inst = ()
self.to_check_worklist[def_id, entry_mono_args] = entry_defn
entry_params = (
entry_defn.params if isinstance(entry_defn, CheckableGenericDef) else []
)
entry_mono_args = tuple(param.to_bound() for param in entry_params)
try:
self.checked[def_id, entry_mono_args] = self.get_checked(
def_id, entry_mono_args
)
except RequiresMonomorphizationError:
# `RequiresMonomorphizationError` is raised whenever we cannot proceed
# checking without having the monomorphization available. In that case,
# we give up and prompt the user to specify the generic arguments.
assert isinstance(entry_defn, CheckableGenericDef)
description = (
f"{entry_defn.description.capitalize()} `{entry_defn.name}`"
)
err = EntryCheckMonomorphizeError(
entry_defn.defined_at, description, entry_defn.params
)
raise GuppyError(err) from None

# Checking the entrypoint will have populated the worklist, so now we need to
# process it
while (
self.types_to_check_worklist
or self.generic_to_check_worklist
Expand Down Expand Up @@ -566,6 +588,29 @@ def params_str(self) -> str:
return ", ".join(f"`{p.name}`" for p in self.params)


@dataclass(frozen=True)
class EntryCheckMonomorphizeError(Error):
title: ClassVar[str] = "Invalid check point"
span_label: ClassVar[str] = (
"{thing} can only be checked if the value{plural_s} of its generic "
"parameter{plural_s} {params_str} {is_are} known"
)
thing: str
params: Sequence[Parameter]

@property
def plural_s(self) -> str:
return "s" if len(self.params) > 1 else ""

@property
def is_are(self) -> str:
return "are" if len(self.params) > 1 else "is"

@property
def params_str(self) -> str:
return ", ".join(f"`{p.name}`" for p in self.params)


def check_entry_point_non_generic(defn: ParsedDef) -> None:
"""Checks if the given definition is a valid compilation entry-point.

Expand Down
Loading