diff --git a/pyproject.toml b/pyproject.toml index 08d5e7a1a0..15eab2877e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -284,6 +284,7 @@ markers = [ 'requires_atlas: tests that require `atlas4py` bindings package', 'requires_dace: tests that require `dace` package', 'requires_gpu: tests that require a NVidia GPU (`cupy` and `cudatoolkit` are required)', + 'requires_jax: tests that require `jax` package', 'uses_applied_shifts: tests that require backend support for applied-shifts', 'uses_can_deref: tests that require backend support for can_deref builtin function', 'uses_composite_shifts: tests that use composite shifts in unstructured domain', diff --git a/src/gt4py/next/otf/compiled_program.py b/src/gt4py/next/otf/compiled_program.py index b48db16f7e..03fcd1fe10 100644 --- a/src/gt4py/next/otf/compiled_program.py +++ b/src/gt4py/next/otf/compiled_program.py @@ -113,6 +113,14 @@ def compile_variant_hook( }, ) + if __debug__: + warnings.warn( + "Python is not running in optimized mode, which may impact performance when using a" + " compiled backend. Consider running with `python -O` or setting the environment" + " variable `PYTHONOPTIMIZE=1`.", + stacklevel=6, + ) + @hook_machinery.context_hook def compiled_program_call_context( @@ -379,11 +387,12 @@ def __call__( # type, add the argument types to the cache key as the argument types are used during # compilation. In case the program is not generic we can avoid the potentially # expensive type deduction for all arguments and not include it in the key. - warnings.warn( - "Calling generic programs / direct calls to scan operators are not optimized. " - "Consider calling a specialized version instead.", - stacklevel=2, - ) + if enable_jit: + warnings.warn( + "Calling generic programs / direct calls to scan operators are not optimized. " + "Consider calling a specialized version instead.", + stacklevel=3, + ) arg_specialization_key = eve_utils.content_hash( ( tuple(type_translation.from_value(arg) for arg in canonical_args), diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py index 18c6c26ff4..702adbe127 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py @@ -5,6 +5,8 @@ # # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +import inspect +import warnings from typing import Optional, NamedTuple from unittest import mock @@ -991,3 +993,75 @@ def testee(inp: tuple[cases.IField, cases.IField, float], out: NamedTupleNamedCo arguments.FieldDomainDescriptor(out[1].domain), ), } + + +def test_warn_if_not_optimized_on_jit_call(cartesian_case, compile_testee): + """A warning is emitted when calling a compiled program without Python's -O flag.""" + if cartesian_case.backend is None: + pytest.skip("Embedded compiled program doesn't make sense.") + + args, kwargs = cases.get_default_data(cartesian_case, compile_testee) + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + call_lineno = inspect.currentframe().f_lineno + 1 + compile_testee(*args, offset_provider=cartesian_case.offset_provider, **kwargs) + + optimized_warnings = [w for w in caught if "optimized" in str(w.message)] + assert len(optimized_warnings) == 1 + w = optimized_warnings[0] + assert w.filename == __file__ + assert w.lineno == call_lineno + + +def test_warn_if_not_optimized_on_explicit_compile(cartesian_case, compile_testee): + """A warning is emitted when pre-compiling a program without Python's -O flag.""" + if cartesian_case.backend is None: + pytest.skip("Embedded compiled program doesn't make sense.") + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + call_lineno = inspect.currentframe().f_lineno + 1 + compile_testee.compile(offset_provider=cartesian_case.offset_provider) + + optimized_warnings = [w for w in caught if "optimized" in str(w.message)] + assert len(optimized_warnings) == 1 + w = optimized_warnings[0] + assert w.filename == __file__ + assert w.lineno == call_lineno + + +@pytest.fixture +def scan_operator_testee(cartesian_case): + """A scan operator called directly as a FieldOperator — the only case where _is_generic=True.""" + if cartesian_case.backend is None: + pytest.skip("Embedded compiled program doesn't make sense.") + + @gtx.scan_operator(axis=KDim, forward=True, init=0, backend=cartesian_case.backend) + def testee(carry: gtx.int32, inp: gtx.int32) -> gtx.int32: + return carry + inp + + return testee + + +@pytest.mark.uses_scan +def test_warn_on_direct_scan_operator_call(cartesian_case, scan_operator_testee): + """A warning is emitted when a scan operator is called directly as a FieldOperator. + + Scan operators called directly (not wrapped in a @program) are the only case where + _is_generic is True, triggering the 'not optimized' generic-program warning. + """ + k_size = cartesian_case.default_sizes[KDim] + inp = cartesian_case.as_field([KDim], np.arange(k_size, dtype=np.int32)) + out = cartesian_case.as_field([KDim], np.zeros(k_size, dtype=np.int32)) + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + call_lineno = inspect.currentframe().f_lineno + 1 + scan_operator_testee(inp, out=out, offset_provider=cartesian_case.offset_provider) + + generic_warnings = [w for w in caught if "generic" in str(w.message)] + assert len(generic_warnings) == 1 + w = generic_warnings[0] + assert w.filename == __file__ + assert w.lineno == call_lineno