Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ markers = [
'requires_atlas: tests that require `atlas4py` bindings package',
'requires_dace: tests that require `dace` package',
'requires_gpu: tests that require a NVidia GPU (`cupy` and `cudatoolkit` are required)',
'requires_jax: tests that require `jax` package',
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unrelated fix that I sneak in...

'uses_applied_shifts: tests that require backend support for applied-shifts',
'uses_can_deref: tests that require backend support for can_deref builtin function',
'uses_composite_shifts: tests that use composite shifts in unstructured domain',
Expand Down
21 changes: 16 additions & 5 deletions src/gt4py/next/otf/compiled_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,16 @@ def compile_variant_hook(
},
)

if __debug__:
# Note: We set the stack level to point to something internally as we don't want to show this warning more than once.
# It's an ad-hoc pragramatic choice that could be revisited in the future.
warnings.warn(
"Python is not running in optimized mode, which may impact performance when using a"
" compiled backend. Consider running with `python -O` or setting the environment"
" variable `PYTHONOPTIMIZE=1`.",
stacklevel=3,
)


@hook_machinery.context_hook
def compiled_program_call_context(
Expand Down Expand Up @@ -379,11 +389,12 @@ def __call__(
# type, add the argument types to the cache key as the argument types are used during
# compilation. In case the program is not generic we can avoid the potentially
# expensive type deduction for all arguments and not include it in the key.
warnings.warn(
"Calling generic programs / direct calls to scan operators are not optimized. "
"Consider calling a specialized version instead.",
stacklevel=2,
)
if enable_jit:
warnings.warn(
"Calling generic programs / direct calls to scan operators are not optimized. "
"Consider calling a specialized version instead.",
stacklevel=3,
)
arg_specialization_key = eve_utils.content_hash(
(
tuple(type_translation.from_value(arg) for arg in canonical_args),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause
import inspect
import warnings
from typing import Optional, NamedTuple
from unittest import mock

Expand Down Expand Up @@ -991,3 +993,39 @@ def testee(inp: tuple[cases.IField, cases.IField, float], out: NamedTupleNamedCo
arguments.FieldDomainDescriptor(out[1].domain),
),
}


@pytest.fixture
def scan_operator_testee(cartesian_case):
"""A scan operator called directly as a FieldOperator — the only case where _is_generic=True."""
if cartesian_case.backend is None:
pytest.skip("Embedded compiled program doesn't make sense.")

@gtx.scan_operator(axis=KDim, forward=True, init=0, backend=cartesian_case.backend)
def testee(carry: gtx.int32, inp: gtx.int32) -> gtx.int32:
return carry + inp

return testee


@pytest.mark.uses_scan
def test_warn_on_direct_scan_operator_call(cartesian_case, scan_operator_testee):
"""A warning is emitted when a scan operator is called directly as a FieldOperator.

Scan operators called directly (not wrapped in a @program) are the only case where
_is_generic is True, triggering the 'not optimized' generic-program warning.
"""
k_size = cartesian_case.default_sizes[KDim]
inp = cartesian_case.as_field([KDim], np.arange(k_size, dtype=np.int32))
out = cartesian_case.as_field([KDim], np.zeros(k_size, dtype=np.int32))

with warnings.catch_warnings(record=True) as caught:
warnings.simplefilter("always")
call_lineno = inspect.currentframe().f_lineno + 1
scan_operator_testee(inp, out=out, offset_provider=cartesian_case.offset_provider)

generic_warnings = [w for w in caught if "generic" in str(w.message)]
assert len(generic_warnings) == 1
w = generic_warnings[0]
assert w.filename == __file__
assert w.lineno == call_lineno
Loading