diff --git a/src/gt4py/next/config.py b/src/gt4py/next/config.py index 0e6bf4ff5d..259b68cc33 100644 --- a/src/gt4py/next/config.py +++ b/src/gt4py/next/config.py @@ -12,6 +12,7 @@ import enum import os import pathlib +import warnings from typing import Final @@ -124,6 +125,21 @@ def env_flag_to_int(name: str, default: int) -> int: DUMP_METRICS_AT_EXIT: str | None = None +#: Filter out DaCe related warnings. If not set warnings will be suppressed if the +#: code runs in no debug mode. +SKIP_WARNINGS: bool = env_flag_to_bool("GT4PY_SKIP_WARNINGS", default=not __debug__) + + +if SKIP_WARNINGS: + # NOTE: Ideally we would suppress the warnings using context managers directly in + # the backend. However, because this is not thread safe in Python versions before + # 3.14, we have to do it here. + warnings.filterwarnings(action="ignore", module="^dace(\..+)?") + warnings.filterwarnings( + action="ignore", module="^gt4py.next.program_processors.runners.dace(\..+)?" + ) + + def _init_dump_metrics_filename() -> str: return f"gt4py_metrics_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.json" diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/concat_where_mapper.py b/src/gt4py/next/program_processors/runners/dace/transformations/concat_where_mapper.py index 8052426f33..535c9cb2eb 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/concat_where_mapper.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/concat_where_mapper.py @@ -1353,18 +1353,16 @@ def _handle_special_case_of_gt4py_scan_point( if _handle_special_case_of_gt4py_scan_point_impl( state, descending_point, concat_node, consumed_subset ): - if __debug__: - warnings.warn( - f"Special rule applied to `concat_where`-inline `{concat_node.data}` into `{nsdfg.label}`.", - stacklevel=1, - ) + warnings.warn( + f"Special rule applied to `concat_where`-inline `{concat_node.data}` into `{nsdfg.label}`.", + stacklevel=1, + ) return True else: - if __debug__: - warnings.warn( - f"Special rule applied to `concat_where`-inline `{concat_node.data}` into `{nsdfg.label}` was rejected.", - stacklevel=1, - ) + warnings.warn( + f"Special rule applied to `concat_where`-inline `{concat_node.data}` into `{nsdfg.label}` was rejected.", + stacklevel=1, + ) return False diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/__init__.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/__init__.py index a576665ee3..5d748e0586 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/__init__.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/__init__.py @@ -10,4 +10,4 @@ #: Attribute defining package-level marks used by a custom pytest hook. -package_pytestmarks = [pytest.mark.usefixtures("common_dace_config")] +package_pytestmarks = [pytest.mark.requires_dace, pytest.mark.usefixtures("common_dace_config")] diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py index f1c8da0143..2258a0931c 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py @@ -20,8 +20,6 @@ from . import util -import dace - def _make_strides_propagation_level3_sdfg() -> dace.SDFG: """Generates the level 3 SDFG (nested-nested) SDFG for `test_strides_propagation()`.""" diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_warnings.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_warnings.py new file mode 100644 index 0000000000..d9665de12d --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_warnings.py @@ -0,0 +1,20 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest +import warnings + +from gt4py.next import config as gtx_config + + +def test_if_warning_is_raised(): + assert not gtx_config.SKIP_WARNINGS, "Tests do not run in debug mode." + + warn_msg = "This is a warning." + with pytest.warns(UserWarning, match=warn_msg): + warnings.warn(warn_msg, UserWarning)