diff --git a/guppylang-internals/src/guppylang_internals/diagnostic.py b/guppylang-internals/src/guppylang_internals/diagnostic.py index 37048ba87..a5c3927dc 100644 --- a/guppylang-internals/src/guppylang_internals/diagnostic.py +++ b/guppylang-internals/src/guppylang_internals/diagnostic.py @@ -174,6 +174,14 @@ class Error(Diagnostic, Protocol): level: ClassVar[Literal[DiagnosticLevel.ERROR]] = DiagnosticLevel.ERROR +@runtime_checkable +@dataclass(frozen=True) +class Warning(Diagnostic, Protocol): + """Compiler diagnostic for non-fatal warnings.""" + + level: ClassVar[Literal[DiagnosticLevel.WARNING]] = DiagnosticLevel.WARNING + + @runtime_checkable @dataclass(frozen=True) class Note(SubDiagnostic, Protocol): diff --git a/guppylang-internals/src/guppylang_internals/error.py b/guppylang-internals/src/guppylang_internals/error.py index 59984e398..000a35f35 100644 --- a/guppylang-internals/src/guppylang_internals/error.py +++ b/guppylang-internals/src/guppylang_internals/error.py @@ -106,7 +106,11 @@ def saved_exception_hook() -> Iterator[None]: def pretty_errors(f: FuncT) -> FuncT: - """Decorator to print custom error banners when a `GuppyError` occurs.""" + """Decorator to print custom error banners when a `GuppyError` occurs. + + This is also the standard boundary for warning collection on top-level engine + operations: wrapped calls participate in one `diagnostic_report()` session. + """ def hook( excty: type[BaseException], err: BaseException, traceback: TracebackType | None @@ -127,7 +131,9 @@ def hook( @functools.wraps(f) def pretty_errors_wrapped(*args: Any, **kwargs: Any) -> Any: - with exception_hook(hook): + from guppylang_internals.warning import diagnostic_report + + with diagnostic_report(), exception_hook(hook): return f(*args, **kwargs) return cast("FuncT", pretty_errors_wrapped) diff --git a/guppylang-internals/src/guppylang_internals/warning.py b/guppylang-internals/src/guppylang_internals/warning.py new file mode 100644 index 000000000..495ce3fee --- /dev/null +++ b/guppylang-internals/src/guppylang_internals/warning.py @@ -0,0 +1,221 @@ +import sys +import warnings +from collections.abc import Iterator +from contextlib import contextmanager +from contextvars import ContextVar +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, NamedTuple + +from guppylang_internals.error import InternalGuppyError + +if TYPE_CHECKING: + from guppylang_internals.diagnostic import Diagnostic + + +class GuppyWarning(UserWarning): + """Warning category for non-fatal compiler diagnostics.""" + + +class _WarningKey(NamedTuple): + """Stable identity for deduplicating warnings within one operation.""" + + # File path passed through to Python's warning machinery, if available. + filename: str | None + # 1-based source line passed through to Python's warning machinery, if available. + lineno: int | None + # 0-based source column used only for deduplicating distinct warnings on one line. + column: int | None + # Concise warning text emitted through Python's warning machinery. + message: str + + +@dataclass(frozen=True) +class PendingWarning: + """Buffered warning waiting to be emitted at the end of a top-level operation.""" + + # Original structured diagnostic used for rich rendering. + diagnostic: "Diagnostic" + # Stable warning identity and Python-warning payload. + _key: _WarningKey + + @property + def message(self) -> str: + """Concise warning text emitted through Python's warning machinery.""" + return self._key.message + + @property + def filename(self) -> str | None: + """Source file reported to Python's warning machinery, if available.""" + return self._key.filename + + @property + def lineno(self) -> int | None: + """1-based source line reported to Python's warning machinery, if available.""" + return self._key.lineno + + @property + def column(self) -> int | None: + """0-based source column used for deduplication within one operation.""" + return self._key.column + + +@dataclass +class DiagnosticSession: + """Per-operation diagnostic state shared across nested compiler calls.""" + + rich_warnings: bool = False + pending_warnings: list[PendingWarning] = field(default_factory=list) + seen_warnings: set[_WarningKey] = field(default_factory=set) + + +_DIAGNOSTIC_SESSION: ContextVar[DiagnosticSession | None] = ContextVar( + "_DIAGNOSTIC_SESSION", default=None +) +_RICH_WARNINGS: ContextVar[bool] = ContextVar("_RICH_WARNINGS", default=False) + + +@contextmanager +def rich_warnings() -> Iterator[None]: + """Enable rich stderr rendering for compiler warnings within the current scope.""" + + token = _RICH_WARNINGS.set(True) + try: + yield + finally: + _RICH_WARNINGS.reset(token) + + +@contextmanager +def diagnostic_report() -> Iterator[None]: + """Collects compiler warnings and flushes them once per top-level operation.""" + + session = _DIAGNOSTIC_SESSION.get() + # Nested compiler entrypoints reuse the same session so one user operation only + # flushes once, at the outermost boundary. + outermost = session is None + token = None + if outermost: + session = DiagnosticSession(rich_warnings=_RICH_WARNINGS.get()) + token = _DIAGNOSTIC_SESSION.set(session) + assert session is not None + + try: + yield + except Exception: + if outermost: + # Failed operations should not emit queued warnings. Clear eagerly so the + # exception path behaves the same whether the warning producer ran before + # or after the eventual failure. + session.pending_warnings.clear() + session.seen_warnings.clear() + raise + else: + if outermost: + # Only the outermost context flushes to Python warnings. Inner contexts + # merely contribute to the shared session. + for pending_warning in session.pending_warnings: + _emit_pending_warning(pending_warning) + finally: + if outermost and token is not None: + # Restore the previous ContextVar value even if warning emission itself + # raises, so subsequent compiler operations start with a clean session. + _DIAGNOSTIC_SESSION.reset(token) + + +def emit_warning(diag: "Diagnostic") -> None: + """Queue or emit a non-fatal compiler warning.""" + + pending_warning = _pending_warning(diag) + session = _DIAGNOSTIC_SESSION.get() + if session is None: + # Warnings emitted outside a diagnostic_report block still surface + # immediately; the session machinery is only needed for batching and + # deduplicating within top-level compiler operations. + _emit_pending_warning(pending_warning) + return + + if pending_warning._key in session.seen_warnings: + # Re-emitting the same warning from nested passes or revisited CFG nodes should + # not duplicate the user-facing Python warning within one operation. + return + + session.seen_warnings.add(pending_warning._key) + session.pending_warnings.append(pending_warning) + + +def _pending_warning(diag: "Diagnostic") -> PendingWarning: + from guppylang_internals.diagnostic import DiagnosticLevel + from guppylang_internals.span import to_span + + if diag.level is not DiagnosticLevel.WARNING: + raise InternalGuppyError("emit_warning expects a warning-level diagnostic") + + filename = None + lineno = None + column = None + if diag.span is not None: + # Python's warning machinery wants file/line information separately rather + # than Guppy's richer span object. + span = to_span(diag.span) + filename = span.start.file + lineno = span.start.line + column = span.start.column + + message = _warning_message(diag) + return PendingWarning( + diagnostic=diag, + # Deduplicate on source location plus rendered message so repeated reports from + # the same site collapse, while distinct warnings on one line still survive. + _key=_WarningKey(filename, lineno, column, message), + ) + + +def _emit_pending_warning(pending_warning: PendingWarning) -> None: + """Emit one queued warning via Python's warning machinery and rich stderr output.""" + + if pending_warning.filename is not None and pending_warning.lineno is not None: + warnings.warn_explicit( + pending_warning.message, + GuppyWarning, + pending_warning.filename, + pending_warning.lineno, + ) + else: + warnings.warn( + pending_warning.message, + GuppyWarning, + stacklevel=2, + ) + + session = _DIAGNOSTIC_SESSION.get() + if session is not None and session.rich_warnings: + sys.stderr.write(_render_warning(pending_warning)) + sys.stderr.write("\n") + + +def _render_warning(pending_warning: PendingWarning) -> str: + from guppylang_internals.diagnostic import DiagnosticsRenderer + from guppylang_internals.engine import DEF_STORE + + renderer = DiagnosticsRenderer(DEF_STORE.sources) + try: + renderer.render_diagnostic(pending_warning.diagnostic) + except KeyError: + return pending_warning.message + return "\n".join(renderer.buffer) + + +def _warning_message(diag: "Diagnostic") -> str: + lines = [diag.rendered_title] + if diag.rendered_span_label: + lines[0] += f": {diag.rendered_span_label}" + if diag.rendered_message: + lines.append(diag.rendered_message) + lines.extend( + [ + f"{child.level.name.lower().capitalize()}: {child.rendered_message}" + for child in diag.children + if child.rendered_message + ] + ) + return "\n".join(lines) diff --git a/guppylang/src/guppylang/__init__.py b/guppylang/src/guppylang/__init__.py index 4494a4e46..41209f65d 100644 --- a/guppylang/src/guppylang/__init__.py +++ b/guppylang/src/guppylang/__init__.py @@ -1,4 +1,5 @@ from guppylang_internals.experimental import enable_experimental_features +from guppylang_internals.warning import GuppyWarning, rich_warnings from guppylang.decorator import guppy from guppylang.module import GuppyModule @@ -8,6 +9,7 @@ __all__ = ( "GuppyModule", + "GuppyWarning", "array", "builtins", "comptime", @@ -17,6 +19,7 @@ "py", "quantum", "qubit", + "rich_warnings", ) # This is updated by our release-please workflow, triggered by this diff --git a/guppylang/src/guppylang/defs.py b/guppylang/src/guppylang/defs.py index 8b6d5203f..a23ce032d 100644 --- a/guppylang/src/guppylang/defs.py +++ b/guppylang/src/guppylang/defs.py @@ -22,6 +22,7 @@ TracingDefMixin, ) from guppylang_internals.tracing.util import hide_trace +from guppylang_internals.warning import diagnostic_report from hugr.envelope import GeneratorDesc from hugr.hugr import Hugr from hugr.metadata import HugrGenerator @@ -85,6 +86,9 @@ class GuppyDefinition(TracingDefMixin): def compile(self) -> Package: """Compile a Guppy definition to HUGR.""" + # Single-definition entrypoints rely on the wrapped engine helpers for warning + # collection. `ENGINE.compile_single()` already establishes the top-level + # diagnostic session via `@pretty_errors`. package: Package = ENGINE.compile_single(self.id).package for mod in package.modules: _update_generator_metadata(mod) @@ -92,6 +96,7 @@ def compile(self) -> Package: def check(self) -> None: """Type-check a Guppy definition.""" + # As above, warning collection is handled by the wrapped engine entrypoint. return ENGINE.check_single(self.id) @@ -243,7 +248,6 @@ def compile_entrypoint(self) -> Package: def compile_function(self) -> Package: """Compile a Guppy function definition to HUGR. - Returns: Package: The compiled package object. """ @@ -276,18 +280,26 @@ def _type_members(self) -> list[DefId]: def compile(self) -> Package: """Compile this collection of definitions into a HUGR package.""" - ENGINE.check(self.members) - # Check fills _type_members with additional members only available after - # checking, so we have to call it before compiling (without an engine reset). - pointer = ENGINE.compile(self.members + self._type_members(), reset=False) + # Unlike the single-definition helpers, a library compile spans multiple + # top-level engine calls. Keep one outer diagnostic session here so warnings + # flush once for the whole user operation rather than once per engine call. + with diagnostic_report(): + ENGINE.check(self.members) + # Check fills _type_members with additional members only available after + # checking, so we have to call it before compiling (without an engine + # reset). + pointer = ENGINE.compile(self.members + self._type_members(), reset=False) for mod in pointer.package.modules: _update_generator_metadata(mod) return pointer.package def check(self) -> None: """Type-check all contained definitions.""" - ENGINE.check(self.members) - ENGINE.check(self._type_members(), reset=False) + # Library checks can trigger more than one top-level engine check, so they need + # their own outer diagnostic session. + with diagnostic_report(): + ENGINE.check(self.members) + ENGINE.check(self._type_members(), reset=False) @dataclass(frozen=True) diff --git a/tests/diagnostics/test_warning_reporting.py b/tests/diagnostics/test_warning_reporting.py new file mode 100644 index 000000000..907ea2bee --- /dev/null +++ b/tests/diagnostics/test_warning_reporting.py @@ -0,0 +1,130 @@ +import warnings +from dataclasses import dataclass +from typing import ClassVar + +import pytest +from guppylang import GuppyWarning, rich_warnings +from guppylang_internals.diagnostic import Note, Warning +from guppylang_internals.engine import DEF_STORE +from guppylang_internals.span import Loc, Span +from guppylang_internals.warning import diagnostic_report, emit_warning + +from tests.util import guppy_warning_records + +file = "warning_test.py" + + +@dataclass(frozen=True) +class SyntheticWarning(Warning): + title: ClassVar[str] = "Synthetic warning" + span_label: ClassVar[str] = "Something suspicious happened" + message: ClassVar[str] = "Additional context for the warning" + + +@dataclass(frozen=True) +class SyntheticNote(Note): + message: ClassVar[str] = "Helpful note" + + +def make_warning() -> SyntheticWarning: + warning = SyntheticWarning(Span(Loc(file, 3, 2), Loc(file, 3, 6))) + warning.add_sub_diagnostic(SyntheticNote(None)) + return warning + + +def register_source() -> None: + DEF_STORE.sources.add_file(file, "x = 0\nx = 1\nwarn()\n") + + +def test_emit_warning_with_source_location(): + """Warnings with spans should preserve filename, line, and message details.""" + with warnings.catch_warnings(record=True) as records: + warnings.simplefilter("always") + with diagnostic_report(): + emit_warning(make_warning()) + + guppy_records = guppy_warning_records(records) + assert len(guppy_records) == 1 + warning = guppy_records[0] + assert warning.category is GuppyWarning + assert warning.filename == file + assert warning.lineno == 3 + assert str(warning.message) == ( + "Synthetic warning: Something suspicious happened\n" + "Additional context for the warning\n" + "Note: Helpful note" + ) + + +def test_nested_reports_flush_on_outer_exit(): + """Nested reporting sessions should flush only when the outermost session exits.""" + with warnings.catch_warnings(record=True) as records: + warnings.simplefilter("always") + with diagnostic_report(): + with diagnostic_report(): + emit_warning(make_warning()) + assert records == [] + assert records == [] + + guppy_records = guppy_warning_records(records) + assert len(guppy_records) == 1 + assert str(guppy_records[0].message).startswith("Synthetic warning") + + +def test_duplicate_warnings_are_deduplicated(): + """The same warning emitted twice in one session should only be reported once.""" + with warnings.catch_warnings(record=True) as records: + warnings.simplefilter("always") + with diagnostic_report(): + emit_warning(make_warning()) + emit_warning(make_warning()) + + guppy_records = guppy_warning_records(records) + assert len(guppy_records) == 1 + + +def test_warning_is_discarded_if_operation_fails(): + """Buffered warnings should be dropped if the enclosing operation raises.""" + + def fail_with_warning() -> None: + with diagnostic_report(): + emit_warning(make_warning()) + raise RuntimeError("boom") + + with warnings.catch_warnings(record=True) as records: + warnings.simplefilter("always") + with pytest.raises(RuntimeError, match="boom"): + fail_with_warning() + + guppy_records = guppy_warning_records(records) + assert len(guppy_records) == 0 + + +def test_rich_warnings_render_to_stderr(capsys): + """Rich warnings should preserve Python warnings and also render diagnostics.""" + register_source() + with warnings.catch_warnings(record=True) as records: + warnings.simplefilter("always") + with rich_warnings(), diagnostic_report(): + emit_warning(make_warning()) + + guppy_records = guppy_warning_records(records) + assert len(guppy_records) == 1 + err = capsys.readouterr().err + assert "Warning: Synthetic warning" in err + assert "3 |" in err + assert "Something suspicious happened" in err + + +def test_nested_rich_warnings_do_not_duplicate_stderr(capsys): + """Nested rich-warning scopes should still render exactly once.""" + register_source() + with warnings.catch_warnings(record=True) as records: + warnings.simplefilter("always") + with rich_warnings(), rich_warnings(), diagnostic_report(): + emit_warning(make_warning()) + + guppy_records = guppy_warning_records(records) + assert len(guppy_records) == 1 + err = capsys.readouterr().err + assert err.count("Warning: Synthetic warning") == 1 diff --git a/tests/error/util.py b/tests/error/util.py index 598cb7c54..8cffa5115 100644 --- a/tests/error/util.py +++ b/tests/error/util.py @@ -1,15 +1,15 @@ -import importlib.util +import importlib import inspect import pathlib import re import sys import pytest +from guppylang_internals.decorator import custom_type +from guppylang_internals.diagnostic import DiagnosticsRenderer, wrap from hugr import tys from hugr.tys import TypeBound -from guppylang_internals.decorator import custom_type -from guppylang_internals.diagnostic import DiagnosticsRenderer, wrap from tests.util import get_wasm_file # Regular expression to match the `~~~~~^^^~~~` highlights that are printed in @@ -40,7 +40,7 @@ def filter_traceback_not_containing(s: str, disallowed_regex: re.Pattern[str]) - def run_error_test(file, capsys, snapshot): file = pathlib.Path(file) - with pytest.raises(Exception) as exc_info: + with pytest.raises(Exception) as exc_info: # noqa: PT011 importlib.import_module(f"tests.error.{file.parent.name}.{file.stem}") # Remove the importlib frames from the traceback by skipping beginning frames until diff --git a/tests/integration/notebooks/rich_warnings.ipynb b/tests/integration/notebooks/rich_warnings.ipynb new file mode 100644 index 000000000..5a2b0d777 --- /dev/null +++ b/tests/integration/notebooks/rich_warnings.ipynb @@ -0,0 +1,79 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "imports", + "metadata": {}, + "outputs": [], + "source": [ + "from dataclasses import dataclass\n", + "from typing import ClassVar\n", + "\n", + "from guppylang import rich_warnings\n", + "from guppylang_internals.diagnostic import Warning\n", + "from guppylang_internals.engine import DEF_STORE\n", + "from guppylang_internals.warning import diagnostic_report, emit_warning\n", + "from guppylang_internals.span import Loc, Span\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "warning", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Warning: Notebook warning (at rich_warning_notebook.guppy:1:0)\n", + " | \n", + "1 | suspicious_code()\n", + " | ^^^^^^^^^^^^^^^^^ This warning was rendered from a notebook\n" + ] + } + ], + "source": [ + "import warnings\n", + "\n", + "@dataclass(frozen=True)\n", + "class NotebookWarning(Warning):\n", + " title: ClassVar[str] = \"Notebook warning\"\n", + " span_label: ClassVar[str] = \"This warning was rendered from a notebook\"\n", + "\n", + "\n", + "file = \"rich_warning_notebook.guppy\"\n", + "DEF_STORE.sources.add_file(file, \"suspicious_code()\\n\")\n", + "\n", + "with warnings.catch_warnings(record=True) as records:\n", + " warnings.simplefilter(\"always\")\n", + " with rich_warnings(), diagnostic_report():\n", + " emit_warning(NotebookWarning(Span(Loc(file, 1, 0), Loc(file, 1, 17))))\n", + "\n", + "assert len(records) == 1\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "guppylang (3.13.11)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tests/integration/test_warning_public_api.py b/tests/integration/test_warning_public_api.py new file mode 100644 index 000000000..d72620e26 --- /dev/null +++ b/tests/integration/test_warning_public_api.py @@ -0,0 +1,168 @@ +import warnings +from dataclasses import dataclass +from types import SimpleNamespace +from typing import ClassVar + +import pytest +from guppylang import rich_warnings +from guppylang.defs import GuppyDefinition, GuppyLibrary +from guppylang_internals.definition.common import DefId, Definition +from guppylang_internals.diagnostic import Error, Warning +from guppylang_internals.engine import ENGINE +from guppylang_internals.engine import DEF_STORE +from guppylang_internals.error import GuppyError +from guppylang_internals.span import Loc, Span +from guppylang_internals.warning import emit_warning +from tests.util import guppy_warning_records + +file = "public_warning_test.py" + + +@dataclass(frozen=True) +class DummyDefinition(Definition): + @property + def description(self) -> str: + return "definition" + + +@dataclass(frozen=True) +class PublicApiWarning(Warning): + title: ClassVar[str] = "Public API warning" + span_label: ClassVar[str] = "Triggered from a public entrypoint" + + +@dataclass(frozen=True) +class PublicApiError(Error): + title: ClassVar[str] = "Public API error" + span_label: ClassVar[str] = "Triggered from a public entrypoint" + + +def make_definition() -> GuppyDefinition: + return GuppyDefinition(DummyDefinition(DefId.fresh(), "dummy", None)) + + +def make_warning() -> PublicApiWarning: + return PublicApiWarning(Span(Loc(file, 5, 1), Loc(file, 5, 4))) + + +def make_error() -> PublicApiError: + return PublicApiError(Span(Loc(file, 8, 1), Loc(file, 8, 4))) + + +def register_source() -> None: + DEF_STORE.sources.add_file(file, "line1\nline2\nline3\nline4\nwarn()\nline6\nerr\n") + + +def install_check_warning(monkeypatch) -> None: + """Synthesize a warning from the inner engine `check()` implementation.""" + + def fake_check(_def_ids, *, reset=True) -> None: + del reset + emit_warning(make_warning()) + + monkeypatch.setattr(ENGINE, "check", fake_check) + + +def install_compile_warning(monkeypatch) -> None: + """Synthesize a warning from the inner engine `_compile()` implementation.""" + + def fake_compile(_def_ids, *, reset=True): + del reset + emit_warning(make_warning()) + pointer = SimpleNamespace(package=SimpleNamespace(modules=[])) + return pointer, [None] + + monkeypatch.setattr(ENGINE, "_compile", fake_compile) + + +@pytest.mark.parametrize( + ("install_warning", "run_entrypoint"), + [ + ( + install_check_warning, + lambda definition: definition.check(), + ), + ( + install_compile_warning, + lambda definition: definition.compile(), + ), + ], +) +def test_single_definition_entrypoints_emit_warning( + monkeypatch, install_warning, run_entrypoint +): + """Single-definition public entrypoints should flush one warning.""" + definition = make_definition() + install_warning(monkeypatch) + + with warnings.catch_warnings(record=True) as records: + warnings.simplefilter("always") + run_entrypoint(definition) + + guppy_records = guppy_warning_records(records) + assert len(guppy_records) == 1 + assert guppy_records[0].filename == file + + +def test_library_compile_emits_warning_once(monkeypatch): + """`GuppyLibrary.compile()` should coalesce warnings across its subcalls.""" + library = GuppyLibrary([]) + install_check_warning(monkeypatch) + + def fake_compile(_def_ids, *, reset=True): + del reset + emit_warning(make_warning()) + return SimpleNamespace(package=SimpleNamespace(modules=[])) + + monkeypatch.setattr(ENGINE, "compile", fake_compile) + + with warnings.catch_warnings(record=True) as records: + warnings.simplefilter("always") + library.compile() + + guppy_records = guppy_warning_records(records) + assert len(guppy_records) == 1 + + +def test_definition_check_discards_warning_on_error(monkeypatch): + """Top-level failures should suppress buffered warnings instead of leaking them.""" + definition = make_definition() + + def fake_check(_def_ids, *, reset=True) -> None: + del reset + emit_warning(make_warning()) + raise GuppyError(make_error()) + + monkeypatch.setattr(ENGINE, "check", fake_check) + + with warnings.catch_warnings(record=True) as records: + warnings.simplefilter("always") + with pytest.raises(GuppyError): + definition.check() + + guppy_records = guppy_warning_records(records) + assert len(guppy_records) == 0 + + +def test_library_compile_rich_warning_emits_stderr_once(monkeypatch, capsys): + """Rich mode should not duplicate rendered warnings across library subcalls.""" + library = GuppyLibrary([]) + register_source() + install_check_warning(monkeypatch) + + def fake_compile(_def_ids, *, reset=True): + del reset + emit_warning(make_warning()) + return SimpleNamespace(package=SimpleNamespace(modules=[])) + + monkeypatch.setattr(ENGINE, "compile", fake_compile) + + with warnings.catch_warnings(record=True) as records: + warnings.simplefilter("always") + with rich_warnings(): + library.compile() + + guppy_records = guppy_warning_records(records) + assert len(guppy_records) == 1 + err = capsys.readouterr().err + assert err.count("Warning: Public API warning") == 1 diff --git a/tests/util.py b/tests/util.py index bef29aeb8..6ec0a0de5 100644 --- a/tests/util.py +++ b/tests/util.py @@ -3,9 +3,12 @@ from pathlib import Path from typing import TYPE_CHECKING +from guppylang import GuppyWarning from guppylang.decorator import custom_guppy_decorator, guppy if TYPE_CHECKING: + from warnings import WarningMessage + from guppylang.defs import GuppyFunctionDefinition from hugr.package import Package, PackagePointer @@ -34,3 +37,7 @@ def get_wasm_file() -> str: def get_h2_wasm_file() -> str: return str(Path(__file__).parent.resolve() / "resources/test.h2.wasm") + + +def guppy_warning_records(records: list[WarningMessage]) -> list[WarningMessage]: + return [warning for warning in records if warning.category is GuppyWarning]