Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ markers = [
'requires_atlas: tests that require `atlas4py` bindings package',
'requires_dace: tests that require `dace` package',
'requires_gpu: tests that require a NVidia GPU (`cupy` and `cudatoolkit` are required)',
'requires_jax: tests that require `jax` package',
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unrelated fix that I sneak in...

'uses_applied_shifts: tests that require backend support for applied-shifts',
'uses_can_deref: tests that require backend support for can_deref builtin function',
'uses_composite_shifts: tests that use composite shifts in unstructured domain',
Expand Down
19 changes: 14 additions & 5 deletions src/gt4py/next/otf/compiled_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,14 @@ def compile_variant_hook(
},
)

if __debug__:
warnings.warn(
"Python is not running in optimized mode, which may impact performance when using a"
" compiled backend. Consider running with `python -O` or setting the environment"
" variable `PYTHONOPTIMIZE=1`.",
stacklevel=6,
)


@hook_machinery.context_hook
def compiled_program_call_context(
Expand Down Expand Up @@ -379,11 +387,12 @@ def __call__(
# type, add the argument types to the cache key as the argument types are used during
# compilation. In case the program is not generic we can avoid the potentially
# expensive type deduction for all arguments and not include it in the key.
warnings.warn(
"Calling generic programs / direct calls to scan operators are not optimized. "
"Consider calling a specialized version instead.",
stacklevel=2,
)
if enable_jit:
warnings.warn(
"Calling generic programs / direct calls to scan operators are not optimized. "
"Consider calling a specialized version instead.",
stacklevel=3,
)
arg_specialization_key = eve_utils.content_hash(
(
tuple(type_translation.from_value(arg) for arg in canonical_args),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause
import inspect
import warnings
from typing import Optional, NamedTuple
from unittest import mock

Expand Down Expand Up @@ -991,3 +993,75 @@ def testee(inp: tuple[cases.IField, cases.IField, float], out: NamedTupleNamedCo
arguments.FieldDomainDescriptor(out[1].domain),
),
}


def test_warn_if_not_optimized_on_jit_call(cartesian_case, compile_testee):
"""A warning is emitted when calling a compiled program without Python's -O flag."""
if cartesian_case.backend is None:
pytest.skip("Embedded compiled program doesn't make sense.")

args, kwargs = cases.get_default_data(cartesian_case, compile_testee)

with warnings.catch_warnings(record=True) as caught:
warnings.simplefilter("always")
call_lineno = inspect.currentframe().f_lineno + 1
compile_testee(*args, offset_provider=cartesian_case.offset_provider, **kwargs)

optimized_warnings = [w for w in caught if "optimized" in str(w.message)]
assert len(optimized_warnings) == 1
w = optimized_warnings[0]
assert w.filename == __file__
assert w.lineno == call_lineno


def test_warn_if_not_optimized_on_explicit_compile(cartesian_case, compile_testee):
"""A warning is emitted when pre-compiling a program without Python's -O flag."""
if cartesian_case.backend is None:
pytest.skip("Embedded compiled program doesn't make sense.")

with warnings.catch_warnings(record=True) as caught:
warnings.simplefilter("always")
call_lineno = inspect.currentframe().f_lineno + 1
compile_testee.compile(offset_provider=cartesian_case.offset_provider)

optimized_warnings = [w for w in caught if "optimized" in str(w.message)]
assert len(optimized_warnings) == 1
w = optimized_warnings[0]
assert w.filename == __file__
assert w.lineno == call_lineno


@pytest.fixture
def scan_operator_testee(cartesian_case):
"""A scan operator called directly as a FieldOperator — the only case where _is_generic=True."""
if cartesian_case.backend is None:
pytest.skip("Embedded compiled program doesn't make sense.")

@gtx.scan_operator(axis=KDim, forward=True, init=0, backend=cartesian_case.backend)
def testee(carry: gtx.int32, inp: gtx.int32) -> gtx.int32:
return carry + inp

return testee


@pytest.mark.uses_scan
def test_warn_on_direct_scan_operator_call(cartesian_case, scan_operator_testee):
"""A warning is emitted when a scan operator is called directly as a FieldOperator.

Scan operators called directly (not wrapped in a @program) are the only case where
_is_generic is True, triggering the 'not optimized' generic-program warning.
"""
k_size = cartesian_case.default_sizes[KDim]
inp = cartesian_case.as_field([KDim], np.arange(k_size, dtype=np.int32))
out = cartesian_case.as_field([KDim], np.zeros(k_size, dtype=np.int32))

with warnings.catch_warnings(record=True) as caught:
warnings.simplefilter("always")
call_lineno = inspect.currentframe().f_lineno + 1
scan_operator_testee(inp, out=out, offset_provider=cartesian_case.offset_provider)

generic_warnings = [w for w in caught if "generic" in str(w.message)]
assert len(generic_warnings) == 1
w = generic_warnings[0]
assert w.filename == __file__
assert w.lineno == call_lineno
Loading