Skip to content
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
"""Fast access to the auto optimization on DaCe."""

import enum
import warnings
from typing import Any, Callable, Optional, Sequence, TypeAlias, Union

import dace
Expand Down Expand Up @@ -280,15 +279,15 @@ def gt_auto_optimize(
if demote_fields is not None:
for field_to_demote in demote_fields:
if field_to_demote not in sdfg.arrays:
warnings.warn(
gtx_transformations.utils.warn(
f"Requested the demotion of field '{field_to_demote}' but the field is unknown.",
stacklevel=0,
)
continue
field_desc = sdfg.arrays[field_to_demote]

if field_desc.transient:
warnings.warn(
gtx_transformations.utils.warn(
f"Requested the demotion of field '{field_to_demote}' but the field is a transient.",
stacklevel=0,
)
Expand Down Expand Up @@ -369,7 +368,7 @@ def gt_auto_optimize(
sdfg.arrays[demoted_field].transient = False

else:
warnings.warn(
gtx_transformations.utils.warn(
f"Could not restore the demoted field '{demoted_field}' back to a global.",
stacklevel=0,
)
Expand Down Expand Up @@ -876,7 +875,7 @@ def _gt_auto_configure_maps_and_strides(

if gpu:
if unit_strides_kind != gtx_common.DimensionKind.HORIZONTAL:
warnings.warn(
gtx_transformations.utils.warn(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is really unexpected, we could also throw an exception.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is more a super serious performance warning but (should) not affect correctness.
I think I will not handle it in a special way.

"The GT4Py DaCe GPU backend assumes that the leading dimension, i.e."
" where stride is 1, is of kind 'HORIZONTAL', however it was"
f" '{unit_strides_kind}'. Furthermore, it should be the last dimension."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import copy
import dataclasses
import functools
import warnings
from typing import Any, Collection, Literal, Mapping, Optional, Sequence, TypeAlias, Union, overload

import dace
Expand Down Expand Up @@ -1353,18 +1352,16 @@ def _handle_special_case_of_gt4py_scan_point(
if _handle_special_case_of_gt4py_scan_point_impl(
state, descending_point, concat_node, consumed_subset
):
if __debug__:
warnings.warn(
f"Special rule applied to `concat_where`-inline `{concat_node.data}` into `{nsdfg.label}`.",
stacklevel=1,
)
gtx_transformations.utils.warn(
f"Special rule applied to `concat_where`-inline `{concat_node.data}` into `{nsdfg.label}`.",
stacklevel=1,
)
return True
else:
if __debug__:
warnings.warn(
f"Special rule applied to `concat_where`-inline `{concat_node.data}` into `{nsdfg.label}` was rejected.",
stacklevel=1,
)
gtx_transformations.utils.warn(
f"Special rule applied to `concat_where`-inline `{concat_node.data}` into `{nsdfg.label}` was rejected.",
stacklevel=1,
)
return False


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from __future__ import annotations

import copy
import warnings
from typing import Any, Callable, Final, Optional, Sequence, Union

import dace
Expand Down Expand Up @@ -622,14 +621,14 @@ def __init__(
if block_size_1d is not None:
self.block_size_1d = block_size_1d
if self.block_size_1d[1] != 1 or self.block_size_1d[2] != 1:
warnings.warn(
gtx_transformations.utils.warn(
f"1D map block size specified with more than one dimension larger than 1. Configured 1D block size: {self.block_size_1d}.",
stacklevel=0,
)
if block_size_2d is not None:
self.block_size_2d = block_size_2d
if self.block_size_2d[2] != 1:
warnings.warn(
gtx_transformations.utils.warn(
f"2D map block size specified with more than twi dimensions larger than 1. Configured 2D block size: {self.block_size_2d}.",
stacklevel=0,
)
Expand Down Expand Up @@ -707,7 +706,7 @@ def apply(

if is_degenerated_1d_map:
num_map_params = 1
warnings.warn(
gtx_transformations.utils.warn(
f"Map '{gpu_map}', size '{map_size}', is a degenerated 1d Map. Handle it as a 1d Map.",
stacklevel=0,
)
Expand All @@ -728,7 +727,7 @@ def apply(
)
block_size[block_size_1D_index] = self.block_size_1d[0]
if block_size_1D_index != 0:
warnings.warn(
gtx_transformations.utils.warn(
f"Blocksize of 1d Map '{gpu_map}' was set to {block_size}, but the iteration index is not the x dimension.",
stacklevel=0,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
# SPDX-License-Identifier: BSD-3-Clause

import copy
import warnings
from typing import Final, Iterable, Optional, Sequence, TypeAlias

import dace
Expand Down Expand Up @@ -564,7 +563,7 @@ def _populate_nested_sdfg(
"Connections to a non replicated node are only allowed if the source node is an AccessNode"
)

warnings.warn(
gtx_transformations.utils.warn(
"Detected computation of data that might not be needed in inline fuser.",
stacklevel=0,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
# SPDX-License-Identifier: BSD-3-Clause

import copy
import warnings
from typing import Any, Callable, Mapping, Optional, TypeAlias, Union

import dace
Expand Down Expand Up @@ -211,7 +210,7 @@ def __init__(
self._bypass_fusion_test = False

if not self.fuse_after_promotion:
warnings.warn(
gtx_transformations.utils.warn(
"Created a `MapPromoter` that does not fuse immediately, which might lead to borderline invalid SDFGs.",
stacklevel=1,
)
Expand Down Expand Up @@ -286,7 +285,7 @@ def can_be_applied(
if (second_map_iterations > 0) != True: # noqa: E712 [true-false-comparison] # SymPy fuzzy bools.
return False
else:
warnings.warn(
gtx_transformations.utils.warn(
"Was unable to determine if the second Map ({second_map_entry}) is executed.",
stacklevel=0,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# be removed because the same data are copied from the first all the way to fourth node.
# The split should be done using SplitAccessNode transformation.

import warnings
from typing import Any, Optional

import dace
Expand All @@ -26,6 +25,7 @@
from dace.sdfg import nodes as dace_nodes
from dace.transformation.passes import analysis as dace_analysis

from gt4py.next.program_processors.runners.dace import transformations as gtx_transformations
from gt4py.next.program_processors.runners.dace.transformations import (
splitting_tools as gtx_dace_split,
strides as gtx_transformations_strides,
Expand Down Expand Up @@ -152,14 +152,14 @@ def can_be_applied(

first_node_writes_subset = gtx_dace_split.subset_merger(ranges_written_to_first_node)
if len(first_node_writes_subset) != 1:
warnings.warn(
gtx_transformations.utils.warn(
"[RemoveAccessNodeCopies] The range of writes to the first node is not a single range.",
stacklevel=0,
)
return False
fourth_node_writes_subset = gtx_dace_split.subset_merger(ranges_written_to_fourth_node)
if len(fourth_node_writes_subset) != 1:
warnings.warn(
gtx_transformations.utils.warn(
"[RemoveAccessNodeCopies] The range of writes to the fourth node is not a single range.",
stacklevel=0,
)
Expand All @@ -170,13 +170,13 @@ def can_be_applied(
[union_written_to_first_node_data, union_written_to_fourth_node_data]
)
if len(union_written_to_common_data) != 1:
warnings.warn(
gtx_transformations.utils.warn(
"[RemoveAccessNodeCopies] The union of the ranges written to the first and fourth nodes is not a single range.",
stacklevel=0,
)
return False
if union_written_to_common_data != first_node_range:
warnings.warn(
gtx_transformations.utils.warn(
"[RemoveAccessNodeCopies] The whole range of the first node is not written.",
stacklevel=0,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import collections
import copy
import uuid
import warnings
from typing import Any, Iterable, Optional, TypeAlias

import dace
Expand Down Expand Up @@ -371,7 +370,7 @@ def gt_inline_nested_sdfg(
nb_inlines_total += 1
if nsdfg_node.label.startswith("scan_"):
# See `gtir_to_sdfg_scan.py::translate_scan()` for more information.
warnings.warn(
gtx_transformations.utils.warn(
f"Inlined '{nsdfg_node.label}' which might be a scan, this might leads to errors during simplification.",
stacklevel=0,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

import warnings
from typing import Any, Iterable, Optional

import dace
Expand Down Expand Up @@ -361,7 +360,7 @@ def _find_edge_reassignment(
if unused_producers:
# This situation is generated by MapFusion, if the intermediate
# AccessNode has to be kept alive.
warnings.warn(
gtx_transformations.utils.warn(
"'SplitAccessNode': found producers "
+ ", ".join((str(p) for p in unused_producers))
+ " that generates data but that is never read.",
Expand Down Expand Up @@ -413,7 +412,7 @@ def _find_producer(
# one Tasklet writes `T[__i, 0]` the other `T[__i, 10]`, where `__i`
# is the iteration index. Then Memlet propagation will set the subset
# to something like `T[:, 0:10]`. So it is not an error in that case.
warnings.warn(
gtx_transformations.utils.warn(
f"Found transient '{self.access_node.data}' that has multiple overlapping"
" incoming edges. Might indicate an error.",
stacklevel=0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

import warnings
from typing import Any, Optional

import dace
Expand Down Expand Up @@ -94,7 +93,7 @@ def can_be_applied(
# TODO(phimuell): For optimal result we should fuse these edges first.
src_to_dest_edges = list(graph.edges_between(src_node, dst_node))
if len(src_to_dest_edges) != 1:
warnings.warn(
gtx_transformations.utils.warn(
f"Found multiple edges between '{src_node.data}' and '{dst_node.data}'",
stacklevel=0,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

import warnings
from typing import Any

import dace
from dace import transformation as dace_transformation
from dace.sdfg import nodes as dace_nodes, utils as dace_sdutils

from gt4py.next.program_processors.runners.dace import transformations as gtx_transformations


# Conditional import because `gt4py.cartesian` uses an older DaCe version without
# `explicit_cf_compatible`.
Expand Down Expand Up @@ -222,7 +223,7 @@ def _check_for_read_write_dependencies(
# AccessNode to the same data. In case it is global this is most likely
# valid. However, I think that simply allow it, is not okay, because
# it might break some assumption in the fuse code.
warnings.warn(
gtx_transformations.utils.warn(
f"Detected that '{first_state}' writes to the data"
f" `{', '.join(all_data_producers.intersection(data_producers[-1]))}`"
" in multiple concurrent subgraphs. This might indicate an error.",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@

"""Common functionality for the transformations/optimization pipeline."""

import functools
import uuid
from typing import Optional, Sequence, TypeVar, Union
import warnings
from typing import Any, Optional, Sequence, TypeVar, Union

import dace
from dace import data as dace_data, libraries as dace_lib, subsets as dace_sbs, symbolic as dace_sym
Expand All @@ -21,6 +23,25 @@
_PassT = TypeVar("_PassT", bound=dace_ppl.Pass)


@functools.wraps(warnings.warn)
def warn(
message: str,
category: type[Warning] | None = None,
stacklevel: int = 1,
source: Any | None = None,
) -> None:
"""Wrapper around `warnings.warn()` function that is only enabled in debug mode."""
if __debug__:
# NOTE: The `skip_file_prefixes` argument was introduced in Python 3.12 and is
# ignored.
warnings.warn(
message=message,
category=category,
stacklevel=(stacklevel + 1),
source=source,
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The check for __debug__ should be done only once at import time. I also think this can be simplified much further (and also maybe renamed to avoid misunderstandings, although it is not critical):

Suggested change
@functools.wraps(warnings.warn)
def warn(
message: str,
category: type[Warning] | None = None,
stacklevel: int = 1,
source: Any | None = None,
) -> None:
"""Wrapper around `warnings.warn()` function that is only enabled in debug mode."""
if __debug__:
# NOTE: The `skip_file_prefixes` argument was introduced in Python 3.12 and is
# ignored.
warnings.warn(
message=message,
category=category,
stacklevel=(stacklevel + 1),
source=source,
)
if __debug__:
from warnings import warn as debug_warn
else:
@functools.wraps(warnings.warn)
def debug_warn(*args, **kwargs) -> None:
pass

Finally, I think this is not the right place for this definition. Consider moving it to something like gt4py.next.utils.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, I've just realized that this redefinition might not be even needed. We could just use the standard warnings.warn function and install a filter for gt4py-produced warnings if __debug__ is false, in the same way I proposed for DaCe.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem is that we have multiple workers and as I understand the documentation doing this is only safe for Python >=3.14, otherwise we would have a data race.

Thus for Python older than 3.14 we would need to inject the filter before we start the workers and maintain the filter (assuming that we use a context) until we are done.
We would thus not only silence the DaCe warnings but pretty much everything.

Copy link
Contributor

@egparedes egparedes Mar 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The way I understand the documentation is that using context handlers is not thread safe, but I'm proposing to add a global filter at import time for warnings coming from gt4py and dace modules. I don't see why this would be a problem.

For more explicit user control, we could also add a config option/env var GT4PY_SKIP_WARNINGS=0/1 which can be explicitly enabled by the user and otherwise default to filter warnings if not in debug mode. Something conceptually like: skip_warnings = bool(os.env.get(GT4PY_SKIP_WARNINGS, not __debug__))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is possible I agree, you just have to make sure that this filter is installed before you start the threads and that you never get rid of the filter.

Do you have an idea where, i.e. which file, we should install this filter?
Directly inside the config.py or is there a better location?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm fine adding this to config.py as a quick workaround for now. It can be cleaned up later in the PR with the new config system.



def unique_name(name: str) -> str:
"""Adds a unique string to `name`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@


#: Attribute defining package-level marks used by a custom pytest hook.
package_pytestmarks = [pytest.mark.usefixtures("common_dace_config")]
package_pytestmarks = [pytest.mark.requires_dace, pytest.mark.usefixtures("common_dace_config")]
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@

from . import util

import dace


def _make_strides_propagation_level3_sdfg() -> dace.SDFG:
"""Generates the level 3 SDFG (nested-nested) SDFG for `test_strides_propagation()`."""
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

import pytest
import numpy as np
import copy

# Without this the test fails if DaCe is not installed, even when the `requires_dace`
# marker is configured in `__init__.py`.
dace = pytest.importorskip("dace")

from gt4py.next.program_processors.runners.dace import (
transformations as gtx_transformations,
)


def test_if_warning_is_raised():
    """Check that the debug-only `warn()` wrapper actually emits a `UserWarning`."""
    expected_message = "This is a warning."

    with pytest.warns(UserWarning, match=expected_message):
        gtx_transformations.utils.warn(expected_message, UserWarning)
Loading