diff --git a/docs/user/next/advanced/HackTheToolchain.md b/docs/user/next/advanced/HackTheToolchain.md index 785cc0b24d..803b4c7dd5 100644 --- a/docs/user/next/advanced/HackTheToolchain.md +++ b/docs/user/next/advanced/HackTheToolchain.md @@ -64,7 +64,7 @@ class PureCpp2WorkflowFactory(gtx.program_processors.runners.gtfn.GTFNCompileWor ) -PureCpp2WorkflowFactory(cmake_build_type=gtx.config.CMAKE_BUILD_TYPE.DEBUG) +PureCpp2WorkflowFactory(cmake_build_type=gtx.config.cmake_build_type.DEBUG) ``` ## Invent new Workflow Types diff --git a/src/gt4py/eve/utils.py b/src/gt4py/eve/utils.py index a0e48ae557..10454b1b44 100644 --- a/src/gt4py/eve/utils.py +++ b/src/gt4py/eve/utils.py @@ -88,7 +88,9 @@ def first(iterable: Iterable[T], *, default: Union[T, NothingType] = NOTHING) -> raise error -def isinstancechecker(type_info: Union[Type, Iterable[Type]]) -> Callable[[Any], bool]: +def isinstancechecker( + type_info: Union[Type, Iterable[Type], types.UnionType], +) -> Callable[[Any], bool]: """Return a callable object that checks if operand is an instance of `type_info`. Examples: @@ -101,18 +103,22 @@ def isinstancechecker(type_info: Union[Type, Iterable[Type]]) -> Callable[[Any], False """ - types: Tuple[Type, ...] = tuple() + accepted_types: Tuple[Type, ...] 
= tuple() if isinstance(type_info, type): - types = (type_info,) + accepted_types = (type_info,) + elif isinstance(type_info, types.UnionType): + accepted_types = type_info.__args__ elif not isinstance(type_info, tuple) and is_collection(type_info): - types = tuple(type_info) + accepted_types = tuple(type_info) else: - types = type_info # type:ignore # it is checked at run-time + accepted_types = type_info # type:ignore # it is checked at run-time - if not isinstance(types, tuple) or not all(isinstance(t, type) for t in types): - raise ValueError(f"Invalid type(s) definition: '{types}'.") + if not isinstance(accepted_types, tuple) or not all( + isinstance(t, type) for t in accepted_types + ): + raise ValueError(f"Invalid type(s) definition: '{accepted_types}'.") - return lambda obj: isinstance(obj, types) + return lambda obj: isinstance(obj, accepted_types) def attrchecker(*names: str) -> Callable[[Any], bool]: @@ -527,6 +533,84 @@ def partial(self, *args: Any, **kwargs: Any) -> fluid_partial: return fluid_partial(self, *args, **kwargs) +class TypeMapping(collections.abc.Mapping[type, _T]): + """ + A mapping from types to values supporting complex type-based dispatching. + + The mapping supports registering values for specific types, and + retrieving values based on the type key, supporting subtyping + relationship exactly in the same way as `functools.singledispatch()` works. + For example, if a value is registered for a base class, it will be returned + for instances of derived classes unless a more specific type is registered. 
+ + Examples: + >>> mapping = TypeMapping(lambda type_: f"Default for {type_}") + >>> mapping[int] = "Integer handler" + >>> mapping[int] + 'Integer handler' + >>> mapping[float] + "Default for <class 'float'>" + + >>> import collections + >>> mapping[tuple] = "Tuple handler" + >>> mapping[tuple] + 'Tuple handler' + >>> mapping[collections.namedtuple("Point", ["x", "y"])] + 'Tuple handler' + """ + + def __init__(self, fallback_factory: Callable[[type], _T]) -> None: + self._fallback_factory = fallback_factory + self._dispatcher = functools.singledispatch(self._fallback_factory) + + def __getitem__(self, type_: type) -> _T: + dispatched = self._dispatcher.dispatch(type_) + return ( + self._fallback_factory(type_) + if dispatched is self._fallback_factory + else cast(_T, dispatched) + ) + + def __setitem__(self, type_: type, value: _T) -> None: + self._dispatcher.register(type_, value) # type: ignore[call-overload] # abusing singledispatch to register any value, not just callables + self.clear_cache() + + def __iter__(self) -> Iterator[type]: + return iter(self._dispatcher.registry) + + def __len__(self) -> int: + return len(self._dispatcher.registry) + + def __contains__(self, type_: object) -> bool: + """Check if a type is registered in the mapping (including via superclasses).""" + return self._dispatcher.dispatch(type_) is not self._fallback_factory + + @overload + def register(self, type_: type, value: _T) -> _T: ... + + @overload + def register(self, type_: type, value: NothingType = NOTHING) -> Callable[[_T], _T]: ... 
+ + def register(self, type_: type, value: _T | NothingType = NOTHING) -> _T | Callable[[_T], _T]: + """Return a decorator to register a value for the given type.""" + + if value is not NOTHING: + assert not isinstance(value, NothingType) + self[type_] = value + return value + else: + + def _decorator(value: _T) -> _T: + self[type_] = value + return value + + return _decorator + + def clear_cache(self) -> None: + """Clear the type dispatching cache.""" + self._dispatcher._clear_cache() + + @overload def with_fluid_partial( func: Literal[None] = None, *args: Any, **kwargs: Any diff --git a/src/gt4py/next/__init__.py b/src/gt4py/next/__init__.py index daa56190dd..0e9e5afec0 100644 --- a/src/gt4py/next/__init__.py +++ b/src/gt4py/next/__init__.py @@ -19,6 +19,12 @@ """ # ruff: noqa: F401 +from __future__ import annotations + + +# reexport the actual configuration manager instance as a public attribute +from ._config import Config as config_type, config # ruff: isort: skip + from .._core.definitions import CUPY_DEVICE_TYPE, Device, DeviceType, is_scalar_type from . import common, ffront, iterator, program_processors, typing from .common import ( @@ -52,6 +58,8 @@ __all__ = [ # submodules "common", + "config", + "config_type", "ffront", "iterator", "program_processors", diff --git a/src/gt4py/next/_config.py b/src/gt4py/next/_config.py new file mode 100644 index 0000000000..75681c0705 --- /dev/null +++ b/src/gt4py/next/_config.py @@ -0,0 +1,555 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +""" +GT4Py configuration system. + +This module defines a configuration system based on these concepts: + +- `OptionDescriptor`: full description of an option (type, default/default_factory, + parser, validator, environment variable mapping, and optional update callback). 
+- `ConfigManager`: stores option values, supports task-local temporary overrides, + and resolves effective values using precedence. +- `Config`: concrete registry of GT4Py public options. + +Configuration can be changed globally in a ConfigManager instance via attribute +assignment or `set()`, and temporarily via `overrides()`. + +The global GT4Py ConfigManager instance is exposed as `gt4py.next.config`. +""" + +from __future__ import annotations + +import contextlib +import contextvars +import dataclasses +import enum +import os +import pathlib +import sys +import types +from collections.abc import Callable, Generator, Mapping +from typing import Any, Final, Generic, Literal, Protocol, TypeVar, cast, final, overload + +from gt4py.eve import utils + + +@final +class _UnsetSentinel: + _instance: _UnsetSentinel | None = None + + def __new__(cls) -> _UnsetSentinel: + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + +UNSET: Final[_UnsetSentinel] = _UnsetSentinel() + + +_T = TypeVar("_T") +_T_contra = TypeVar("_T_contra", contravariant=True) + + +def parse_env_var( + var_name: str, parser: Callable[[str], _T], *, default: _T | None = None +) -> _T | None: + """Get a python value from an environment variable.""" + env_var_value = os.environ.get(var_name, None) + if env_var_value is None: + return default + + try: + return parser(env_var_value) + except Exception as e: + raise RuntimeError( + f"Parsing environment variable '{var_name}' (value: '{env_var_value}') failed!" 
+ ) from e + + +@utils.TypeMapping +def _parse_str(type_: type) -> Callable[[str], Any]: + """Default parser: the type string value as is.""" + if issubclass(type_, enum.Enum): + return lambda value: type_[value] # parse enum values from their names + + return lambda x: type_(x) # type: ignore[call-arg] # use type constructor as parser + + +@_parse_str.register(bool) +def _parse_str_as_bool(value: str) -> bool: + match value.strip().upper(): + case "0" | "FALSE" | "OFF": + return False + case "1" | "TRUE" | "ON": + return True + case _: + raise ValueError( + f"{value} cannot be parsed as a boolean value. Use '0 | FALSE | OFF' or '1 | TRUE | ON'." + ) + + +@_parse_str.register(pathlib.Path) +def _parse_str_as_path(value: str) -> pathlib.Path: + expanded = os.path.expandvars(os.path.expanduser(value)) + return pathlib.Path(expanded) + + +def _type_check_validator(type_: type) -> Callable[[Any], None]: + """Generate a validator function that checks if a value is an instance of the given type.""" + + is_instance_checker = utils.isinstancechecker(type_) + + def validator(value: Any) -> None: + if not is_instance_checker(value): + raise TypeError( + f"Expected value of type '{type_}', got type '{type(value)}' (value: {value})" + ) + + return validator + + +class UpdateScope(str, enum.Enum): + """Scope of a configuration option update.""" + + GLOBAL = sys.intern("global") + CONTEXT = sys.intern("context") + + +class OptionUpdateCallback(Protocol[_T_contra]): + """ + Callback invoked after an option changes. + + Callbacks are invoked after both global (via set() or __setattr__) + and context-local (via overrides()) updates. This allows observers + to react to configuration changes. + """ + + def __call__( + self, new_val: _T_contra, old_val: _T_contra | None, scope: UpdateScope + ) -> None: ... + + +@dataclasses.dataclass(frozen=True, kw_only=True) +class OptionDescriptor(Generic[_T]): + """ + Descriptor for a configuration option. 
+ + Instances of this class should be defined as class attributes of a + `ConfigManager` subclass. This class implements the descriptor protocol + to support the bare attribute-style access to the option value on the + manager instance (e.g. `config.debug`), which will be resolved properly + using the precedence rules defined in `ConfigManager.get()`. + + Attributes: + type: The Python type of this configuration option. + default: Initial fallback value for this option. Mutually exclusive with default_factory. + default_factory: Callable to compute the default value given a ConfigManager instance. + Mutually exclusive with default. + validator: Callable that validates the option value, or "type_check" for isinstance checking. + Set to None to disable validation. + update_callback: Optional callback invoked after the option is updated (globally or in context). + env_var_parser: Optional parser for environment variable values. + env_var_prefix: Prefix for the environment variable name. + name: Name of the option (set automatically via __set_name__). 
+ """ + + option_type: type[_T] | Any + default: dataclasses.InitVar[_T | _UnsetSentinel] = UNSET + default_factory: Callable[[ConfigManager], _T] | None = None + validator: Callable[[Any], Any] | Literal["type_check"] | None = "type_check" + update_callback: OptionUpdateCallback[_T] | None = None + env_var_parser: Callable[[str], _T] | None = None + env_var_prefix: str = "GT4PY_" + name: str = dataclasses.field(init=False, default="") + + def __post_init__(self, default: _T | _UnsetSentinel) -> None: + # Initialize the validator + if self.validator == "type_check": + object.__setattr__(self, "validator", _type_check_validator(self.option_type)) + assert self.validator is None or callable(self.validator) + + # Initialize the default factory based on the provided default/default_factory + if not isinstance(default, _UnsetSentinel): + if self.default_factory is not None: + raise ValueError( + "Cannot specify both default and default_factory for a config option descriptor." + ) + if self.validator is not None: + self.validator(default) + object.__setattr__(self, "default_factory", lambda _: default) + elif self.default_factory is None: + raise ValueError( + "Must specify either default or default_factory for a config option descriptor." + ) + + def __set_name__(self, owner: type, name: str) -> None: + """Set the name of the option based on the attribute name in the owner class.""" + object.__setattr__(self, "name", name) + + @overload + def __get__(self, instance: ConfigManager, owner: type[ConfigManager]) -> _T: ... + + @overload + def __get__(self, instance: None, owner: None) -> OptionDescriptor[_T]: ... + + def __get__( + self, instance: ConfigManager | None, owner: type[ConfigManager] | None = None + ) -> _T | OptionDescriptor[_T]: + """ + Get the configuration option value. + + If accessed on the class (instance is None), returns the descriptor itself. + If accessed on an instance, delegates to ConfigManager.get() to get the + effective current value. 
+ """ + try: + assert isinstance(instance, ConfigManager) + return instance.get(self.name) + except Exception as e: + if instance is None: + # Accessed on the class, return the descriptor itself (e.g. for help()) + return self + raise AttributeError(f"Error reading config option {self.name!r}") from e + + def __set__(self, instance: Any, value: _T) -> None: + """ + Set the global value of the configuration option. + + This delegates to ConfigManager.set() which handles global updates and validation. + """ + assert isinstance(instance, ConfigManager) + instance.set(self.name, value) + + @property + def env_var_name(self) -> str: + """Construct the name of the environment variable corresponding to this option.""" + return f"{self.env_var_prefix}{self.name}".upper() + + +class ConfigManager: + """ + Central configuration manager with attribute-style access. + + Config options are defined as `OptionDescriptor` class attributes in a + concrete subclass of `ConfigManager`. The manager stores global values + for all options and allows temporary overrides in a context manager scope. + + The effective value of an option follows this precedence (highest to lowest): + 1. Active context override via the `overrides()` context manager + 2. Global runtime value set via the `set()` method + 3. Default value from the environment variable (if set) + 4. Default value from the descriptor (either `default` or `default_factory`) + + Example: + >>> class MyConfig(ConfigManager): + ... some_option = OptionDescriptor(option_type=str, default="default_value") + >>> config = MyConfig() + >>> config.get("some_option") # Default value from descriptor + 'default_value' + >>> config.set("some_option", "global_value") # Set global value + >>> config.get("some_option") # Apply precedence rules + 'global_value' + >>> with config.overrides(some_option="temporary_override"): # Temporary override + ... 
config.some_option + 'temporary_override' + """ + + def __init__(self) -> None: + self._descriptors: dict[str, OptionDescriptor[Any]] = { + name: attr + for name, attr in type(self).__dict__.items() + if isinstance(attr, OptionDescriptor) + } + self._keys = set(self._descriptors.keys()) + self._validators: dict[str, Callable[[Any], None]] = { + name: desc.validator + for name, desc in self._descriptors.items() + if callable(desc.validator) + } + self._callbacks: dict[str, OptionUpdateCallback[Any]] = { + name: desc.update_callback + for name, desc in self._descriptors.items() + if desc.update_callback is not None + } + + # An instance-level ContextVar creates isolated context-local state per manager + # instance. Though discouraged in general (values bind to ContextVar identity + # and Context objects hold strong references to ContextVars, so they won't be + # GC'd even if the instance goes out of scope), in this case we really want + # per-registry isolation and we assume only very few ConfigManager instances + # will be ever created. + self._local_context_cvar = contextvars.ContextVar[Mapping[str, Any]]( + f"{self.__class__.__name__}_cvar", default=types.MappingProxyType({}) + ) + + # Option values initialization with environment variable parsing and validation + self._global_context: dict[str, Any] = {} + for name, desc in self._descriptors.items(): + assert desc.default_factory is not None # Guaranteed by __post_init__ + init_value = parse_env_var( + desc.env_var_name, + desc.env_var_parser or _parse_str[desc.option_type], + default=desc.default_factory(self), + ) + if validator := self._validators.get(name): + validator(init_value) + self._global_context[name] = init_value + + def get(self, name: str) -> Any: + """ + Get the effective value of a configuration option. + + Applies precedence rules: context override > global value > environment > default. + + Args: + name: The name of the configuration option. + + Returns: + The effective value of the option. 
+ """ + if name not in self._keys: + raise AttributeError(f"Unrecognized config option: {name}") + if (val := self._local_context_cvar.get().get(name, UNSET)) is not UNSET: + return val + return self._global_context[name] + + def set(self, name: str, value: Any) -> None: + """ + Set the global value of a configuration option. + + Validates the value and invokes any registered callbacks. + + Args: + name: The name of the configuration option. + value: The new value for the option. + """ + if name not in self._keys: + raise AttributeError(f"Unrecognized config option: {name}") + if name in self._local_context_cvar.get(): + raise AttributeError( + f"Cannot set config option {name!r} while it is overridden in a context manager" + ) + if validator := self._validators.get(name): + validator(value) + old_val = self._global_context[name] + self._global_context[name] = value + if callback := self._callbacks.get(name): + callback(value, old_val, UpdateScope.GLOBAL) + + @contextlib.contextmanager + def overrides(self, **overrides: Any) -> Generator[None, None, None]: + """ + Context manager for temporary configuration overrides. + + Overrides are task-local (isolated per thread/async task) and automatically + reverted when exiting the context manager. Nested contexts are supported. + + Args: + **overrides: Configuration option names and their temporary values. + + Example: + >>> with config.overrides(debug=True, verbose_exceptions=True): + ... # Use config with temporary overrides + ... 
pass + >>> # Overrides are automatically reverted here + """ + if overrides.keys() - self._keys: + raise AttributeError( + f"Unrecognized config options: {set(overrides.keys()) - self._keys}" + ) + + old_values: dict[str, Any] = {} + changes: dict[str, Any] = {} + for name, new_value in overrides.items(): + old_value = self.get(name) + if new_value != old_value: + old_values[name] = old_value + changes[name] = new_value + + for name in changes.keys() & self._validators.keys(): + self._validators[name](changes[name]) + + old_context = self._local_context_cvar.get() + new_context = types.MappingProxyType({**old_context, **changes}) + token = self._local_context_cvar.set(new_context) + + try: + for name in changes.keys() & self._callbacks.keys(): + self._callbacks[name](new_context[name], old_values[name], UpdateScope.CONTEXT) + + yield + + finally: + self._local_context_cvar.reset(token) + + for name in changes.keys() & self._callbacks.keys(): + self._callbacks[name](old_values[name], new_context.get(name), UpdateScope.CONTEXT) + + def as_dict(self) -> dict[str, Any]: + """ + Get the current effective configuration options as a dictionary. + + Returns all configuration options with their effective values, preserving + the order they were defined in the class. + """ + # We use self._descriptors to preserve the order of options as defined in the class. + return {name: self.get(name) for name in self._descriptors.keys()} + + def _option_descriptors_(self) -> types.MappingProxyType[str, OptionDescriptor]: + """ + Get the option descriptors. + + Returns a read-only mapping of option names to their descriptors. + This is useful for introspection and documentation purposes. + """ + return types.MappingProxyType(self._descriptors) + + +def _parse_dump_metrics_filename(value: str) -> bool | pathlib.Path: + try: + return _parse_str[bool](value) + except Exception: + # If parsing as a bool fails, try parsing as a path. 
+ # This allows users to specify a file path or a boolean value for this option. + return _parse_str[pathlib.Path](value) + + +class Config(ConfigManager): + """ + GT4Py configuration manager. + + This class is used to register and manage all configuration options for GT4Py. + All publicly exposed options should be defined here as OptionDescriptor instances. + + Options defined here can be configured via: + - Environment variables (GT4PY_OPTION_NAME format) + - Direct calls to config.set() + - Context manager overrides with config.overrides() + """ + + ## -- Debug options -- + #: Master debug flag. It changes defaults for all the other options to be as helpful + #: for debugging as possible. Environment variable: GT4PY_DEBUG + debug = OptionDescriptor(option_type=bool, default=False) + + #: Verbose flag for DSL compilation errors. Defaults to the value of debug. + #: Environment variable: GT4PY_VERBOSE_EXCEPTIONS + verbose_exceptions = OptionDescriptor[bool]( + option_type=bool, default_factory=(lambda cfg: cast(Config, cfg).debug) + ) + + ## -- Instrumentation options -- + #: User-defined level to enable metrics at lower or equal level. + #: Enabling metrics collection will do extra synchronization and will have + #: impact on runtime performance. Environment variable: GT4PY_COLLECT_METRICS_LEVEL + collect_metrics_level = OptionDescriptor(option_type=int, default=0) + + #: Add GPU trace markers (NVTX, ROC-TX) to the generated code, at compile time. + #: Environment variable: GT4PY_ADD_GPU_TRACE_MARKERS + #: FIXME[#2447](egparedes): compile-time setting, should be included in the build cache key. + add_gpu_trace_markers = OptionDescriptor(option_type=bool, default=False) + + #: File path to dump collected metrics at exit, if GT4PY_COLLECT_METRICS_LEVEL is enabled. + #: If set to a True value, it defaults to "gt4py_metrics_YYYYMMDD_HHMMSS.json" in + #: the current folder. 
+ dump_metrics_at_exit = OptionDescriptor( + option_type=bool | pathlib.Path, + default=False, + env_var_parser=_parse_dump_metrics_filename, # type: ignore[arg-type] # mypy gets confused with the overloaded return type of the parser + ) + + ## -- Build options -- + class BuildCacheLifetime(enum.Enum): + """Build cache lifetime modes.""" + + SESSION = "session" + PERSISTENT = "persistent" + + #: Whether generated code projects should be kept around between runs. + #: - SESSION: generated code projects get destroyed when the interpreter shuts down + #: - PERSISTENT: generated code projects are written to build_cache_dir and persist between runs + #: Defaults to PERSISTENT in debug mode, SESSION otherwise. + #: Environment variable: GT4PY_BUILD_CACHE_LIFETIME + build_cache_lifetime = OptionDescriptor[BuildCacheLifetime]( + option_type=BuildCacheLifetime, + default_factory=( + lambda cfg: cast(Config, cfg).BuildCacheLifetime.PERSISTENT + if cast(Config, cfg).debug + else cast(Config, cfg).BuildCacheLifetime.SESSION + ), + ) + + #: Where generated code projects should be persisted when build_cache_lifetime is PERSISTENT. + #: Supports ~ expansion and environment variable substitution ($VAR, ${VAR}). + #: The actual cache directory will be this path with '/.gt4py_cache' appended. + #: Environment variable: GT4PY_BUILD_CACHE_DIR_ROOT + build_cache_dir_root = OptionDescriptor(option_type=pathlib.Path, default=pathlib.Path.cwd()) + + @property + def build_cache_dir(self) -> pathlib.Path: + assert isinstance(self.build_cache_dir_root, pathlib.Path) + return self.build_cache_dir_root / ".gt4py_cache" + + class CMakeBuildType(enum.Enum): + """CMake build types enum. + + Member values have to be valid CMake syntax. + + Attributes: + DEBUG: Debug build with symbols and no optimization. + RELEASE: Release build with optimization and no symbols. + REL_WITH_DEB_INFO: Release build with optimization and debug symbols. + MIN_SIZE_REL: Release build optimized for minimal size. 
+ """ + + DEBUG = "Debug" + RELEASE = "Release" + REL_WITH_DEB_INFO = "RelWithDebInfo" + MIN_SIZE_REL = "MinSizeRel" + + #: Build type to be used when CMake is used to compile generated code. + #: Defaults to DEBUG in debug mode, RELEASE otherwise. + #: Might have no effect when CMake is not used as part of the toolchain. + #: Environment variable: GT4PY_CMAKE_BUILD_TYPE + #: FIXME[#2447](egparedes): compile-time setting, should be included in the build cache key. + cmake_build_type = OptionDescriptor[CMakeBuildType]( + option_type=CMakeBuildType, + default_factory=( + lambda cfg: cast(Config, cfg).CMakeBuildType.DEBUG + if cast(Config, cfg).debug + else cast(Config, cfg).CMakeBuildType.RELEASE + ), + ) + + #: Number of threads to use for compilation (0 = synchronous compilation). + #: Default behavior: + #: - Uses os.cpu_count() if available (TODO: Python >= 3.13 use process_cpu_count()) + #: - Falls back to 1 if os.cpu_count() returns None + #: - Caps the value at 32 to avoid excessive resource usage on HPC systems + #: Environment variable: GT4PY_BUILD_JOBS + build_jobs = OptionDescriptor( + option_type=int, + default_factory=lambda ctx: min(os.cpu_count() or 1, 32), + ) + + ## -- Code-generation options -- + #: Experimental, use at your own risk: assume horizontal dimension has stride 1 + #: Environment variable: GT4PY_UNSTRUCTURED_HORIZONTAL_HAS_UNIT_STRIDE + #: FIXME[#2447](egparedes): compile-time setting, should be included in the build cache key. + unstructured_horizontal_has_unit_stride = OptionDescriptor(option_type=bool, default=False) + + #: The default for whether to allow jit-compilation for a compiled program. + #: This default can be overridden per program via their respective APIs. + #: Environment variable: GT4PY_ENABLE_JIT_DEFAULT + enable_jit_default = OptionDescriptor(option_type=bool, default=True) + + +#: Global singleton instance of the GT4Py configuration manager. 
+#: Use this to access and modify configuration options: config.debug, config.set(...), etc. +config = Config() diff --git a/src/gt4py/next/config.py b/src/gt4py/next/config.py deleted file mode 100644 index 0e6bf4ff5d..0000000000 --- a/src/gt4py/next/config.py +++ /dev/null @@ -1,142 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from __future__ import annotations - -import datetime -import enum -import os -import pathlib -from typing import Final - - -class BuildCacheLifetime(enum.Enum): - SESSION = 1 - PERSISTENT = 2 - - -class CMakeBuildType(enum.Enum): - """ - CMake build types enum. - - Member values have to be valid CMake syntax. - """ - - DEBUG = "Debug" - RELEASE = "Release" - REL_WITH_DEB_INFO = "RelWithDebInfo" - MIN_SIZE_REL = "MinSizeRel" - - -def env_flag_to_bool(name: str, default: bool) -> bool: - """Convert environment variable string variable to a bool value.""" - flag_value = os.environ.get(name, None) - if flag_value is None: - return default - match flag_value.lower(): - case "0" | "false" | "off": - return False - case "1" | "true" | "on": - return True - case _: - raise ValueError( - "Invalid GT4Py environment flag value: use '0 | false | off' or '1 | true | on'." - ) - - -def env_flag_to_int(name: str, default: int) -> int: - """Convert environment variable string variable to an int value.""" - flag_value = os.environ.get(name, None) - if flag_value is None: - return default - try: - return int(flag_value) - except ValueError: - raise ValueError( - f"Invalid GT4Py environment flag value: {flag_value} is not an integer." - ) from None - - -#: Master debug flag -#: Changes defaults for all the other options to be as helpful for debugging as possible. -#: Does not override values set in environment variables. 
-DEBUG: Final[bool] = env_flag_to_bool("GT4PY_DEBUG", default=False) - - -#: Verbose flag for DSL compilation errors -VERBOSE_EXCEPTIONS: bool = env_flag_to_bool( - "GT4PY_VERBOSE_EXCEPTIONS", default=True if DEBUG else False -) - - -#: Where generated code projects should be persisted. -#: Only active if BUILD_CACHE_LIFETIME is set to PERSISTENT -BUILD_CACHE_DIR: pathlib.Path = ( - pathlib.Path(os.environ.get("GT4PY_BUILD_CACHE_DIR", pathlib.Path.cwd())) / ".gt4py_cache" -) - - -#: Whether generated code projects should be kept around between runs. -#: - SESSION: generated code projects get destroyed when the interpreter shuts down -#: - PERSISTENT: generated code projects are written to BUILD_CACHE_DIR and persist between runs -BUILD_CACHE_LIFETIME: BuildCacheLifetime = BuildCacheLifetime[ - os.environ.get("GT4PY_BUILD_CACHE_LIFETIME", "persistent" if DEBUG else "session").upper() -] - -#: Build type to be used when CMake is used to compile generated code. -#: Might have no effect when CMake is not used as part of the toolchain. -# FIXME[#2447](egparedes): compile-time setting, should be included in the build cache key. -CMAKE_BUILD_TYPE: CMakeBuildType = CMakeBuildType[ - os.environ.get("GT4PY_CMAKE_BUILD_TYPE", "debug" if DEBUG else "release").upper() -] - -#: Experimental, use at your own risk: assume horizontal dimension has stride 1 -# FIXME[#2447](egparedes): compile-time setting, should be included in the build cache key. -UNSTRUCTURED_HORIZONTAL_HAS_UNIT_STRIDE: bool = env_flag_to_bool( - "GT4PY_UNSTRUCTURED_HORIZONTAL_HAS_UNIT_STRIDE", default=False -) - -#: Add GPU trace markers (NVTX, ROC-TX) to the generated code, at compile time. -# FIXME[#2447](egparedes): compile-time setting, should be included in the build cache key. -ADD_GPU_TRACE_MARKERS: bool = env_flag_to_bool("GT4PY_ADD_GPU_TRACE_MARKERS", default=False) - -#: Number of threads to use to use for compilation (0 = synchronous compilation). 
-#: Default: -#: - use os.cpu_count(), TODO(havogt): in Python >= 3.13 use `process_cpu_count()` -#: - if os.cpu_count() is None we are conservative and use 1 job, -#: - if the number is huge (e.g. HPC system) we limit to a smaller number -BUILD_JOBS: int = int(os.environ.get("GT4PY_BUILD_JOBS", min(os.cpu_count() or 1, 32))) - -#: User-defined level to enable metrics at lower or equal level. -#: Enabling metrics collection will do extra synchronization and will have -#: impact on runtime performance. -COLLECT_METRICS_LEVEL: int = env_flag_to_int("GT4PY_COLLECT_METRICS_LEVEL", default=0) - - -#: File path to dump collected metrics at exit, if COLLECT_METRICS_LEVEL is enabled. -#: If set to a True value, it defaults to "gt4py_metrics_YYYYMMDD_HHMMSS.json" in -#: the current folder. -DUMP_METRICS_AT_EXIT: str | None = None - - -def _init_dump_metrics_filename() -> str: - return f"gt4py_metrics_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.json" - - -_dump_metrics_at_exit_env = os.environ.get("GT4PY_DUMP_METRICS_AT_EXIT", None) -if _dump_metrics_at_exit_env is not None: - try: - if env_flag_to_bool("GT4PY_DUMP_METRICS_AT_EXIT", default=False): - DUMP_METRICS_AT_EXIT = _init_dump_metrics_filename() - except ValueError: - DUMP_METRICS_AT_EXIT = _dump_metrics_at_exit_env - - -#: The default for whether to allow jit-compilation for a compiled program. -#: This default can be overriden per program. 
-ENABLE_JIT_DEFAULT: bool = env_flag_to_bool("GT4PY_ENABLE_JIT_DEFAULT", default=True) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index fa99f5fabd..2a22c85283 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -994,7 +994,8 @@ def _concat_where( return cls_.from_array(result_array, domain=result_domain) -NdArrayField.register_builtin_func(experimental.concat_where, _concat_where) # type: ignore[arg-type] # TODO(havogt): this is still the "old" concat_where, needs to be replaced in a next PR +# TODO(havogt): this is still the "old" concat_where, needs to be replaced in a next PR +NdArrayField.register_builtin_func(experimental.concat_where, _concat_where) # type: ignore[has-type] # mypy bug? mypy cannot see experimental.concat_where type here def _make_reduction( diff --git a/src/gt4py/next/errors/excepthook.py b/src/gt4py/next/errors/excepthook.py index 6bd084eebd..dffc2de321 100644 --- a/src/gt4py/next/errors/excepthook.py +++ b/src/gt4py/next/errors/excepthook.py @@ -44,7 +44,7 @@ def compilation_error_hook( """ # in hard crashes of the interpreter, the `exceptions` module might be partially unloaded if exceptions.DSLError is not None and isinstance(value, exceptions.DSLError): - exc_strs = _format_uncaught_error(value, config.VERBOSE_EXCEPTIONS) + exc_strs = _format_uncaught_error(value, config.verbose_exceptions) print("".join(exc_strs), file=sys.stderr) else: fallback(type_, value, tb) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 073e3ff95b..e5319601fc 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -20,7 +20,7 @@ import typing import warnings from collections.abc import Callable -from typing import Any, Generic, Optional, Sequence, TypeAlias +from typing import Any, Final, Generic, Optional, Sequence, TypeAlias from gt4py import eve from gt4py._core import 
definitions as core_defs @@ -54,7 +54,7 @@ DEFAULT_BACKEND: next_backend.Backend | None = None -ProgramCallMetricsCollector = metrics.make_collector( +ProgramCallMetricsCollector: Final[type[metrics.BaseMetricsCollector]] = metrics.make_collector( # type: ignore[has-type] # mypy bug? mypy cannot see metrics.make_collector type here level=metrics.MINIMAL, metric_name=metrics.TOTAL_METRIC ) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 6b19e7cc1f..39ac4b32d8 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -141,7 +141,7 @@ def past_to_gtir(inp: ConcretePASTProgramDef) -> definitions.CompilableProgramDe inp.args, args=args, kwargs=kwargs, column_axis=_column_axis(all_closure_vars) ) - if config.DEBUG or inp.data.debug: + if config.debug or inp.data.debug: devtools.debug(itir_program) return definitions.CompilableProgramDef(data=itir_program, args=compile_time_args) diff --git a/src/gt4py/next/instrumentation/metrics.py b/src/gt4py/next/instrumentation/metrics.py index 59027cff04..955d42bae6 100644 --- a/src/gt4py/next/instrumentation/metrics.py +++ b/src/gt4py/next/instrumentation/metrics.py @@ -29,7 +29,7 @@ import numpy as np from gt4py.eve import extended_typing as xtyping, utils -from gt4py.eve.extended_typing import Any, Final +from gt4py.eve.extended_typing import Any, Final, assert_never from gt4py.next import config from gt4py.next.otf import arguments @@ -53,17 +53,17 @@ def is_any_level_enabled() -> bool: """Check if any metrics collection level is enabled.""" - return config.COLLECT_METRICS_LEVEL > DISABLED + return config.collect_metrics_level > DISABLED def is_level_enabled(level: int) -> bool: """Check if a given metrics collection level is enabled.""" - return config.COLLECT_METRICS_LEVEL >= level + return config.collect_metrics_level >= level def get_current_level() -> int: """Retrieve the current metrics collection level (from the configuration 
module).""" - return config.COLLECT_METRICS_LEVEL + return config.collect_metrics_level @dataclasses.dataclass(frozen=True) @@ -441,17 +441,31 @@ def dump_json( pathlib.Path(filename).write_text(dumps_json(metric_sources)) +def _init_dump_metrics_filename() -> pathlib.Path: + return pathlib.Path(f"gt4py_metrics_{time.strftime('%Y%m%d_%H%M%S')}.json") + + # Handler registration to automatically dump metrics at program exit if # the corresponding configuration flag is set. def _dump_metrics_at_exit() -> None: """Dump collected metrics to a file at program exit if required.""" # It is assumed that 'gt4py.next.config' is still alive at this point - if config.DUMP_METRICS_AT_EXIT and (is_any_level_enabled() or sources): + match config.dump_metrics_at_exit: + case False: + metrics_dump_file = None + case True: + metrics_dump_file = _init_dump_metrics_filename() + case pathlib.Path() as user_path: + metrics_dump_file = user_path + case _ as unreachable: + assert_never(unreachable) + + if metrics_dump_file is not None and (is_any_level_enabled() or sources): try: - pathlib.Path(config.DUMP_METRICS_AT_EXIT).write_text(dumps_json()) + metrics_dump_file.write_text(dumps_json()) print( - f"--- atexit: GT4Py performance metrics saved at {config.DUMP_METRICS_AT_EXIT} ---", + f"--- atexit: GT4Py performance metrics saved at {metrics_dump_file} ---", file=sys.stderr, ) except Exception as e: diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index b2f742b4f5..72748fb533 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -1630,7 +1630,7 @@ def _validate_domain(domain: Domain, offset_provider_type: common.OffsetProvider ) -@runtime.set_at.register(EMBEDDED) +@runtime.set_at.register(EMBEDDED) # type: ignore[has-type] # mypy bug? 
mypy cannot see runtime.set_at type here def set_at( expr: common.Field, domain_like: xtyping.MaybeNestedInTuple[common.DomainLike], @@ -1640,12 +1640,12 @@ def set_at( operators._tuple_assign_field(target, expr, domain) -@runtime.get_domain_range.register(EMBEDDED) +@runtime.get_domain_range.register(EMBEDDED) # type: ignore[has-type] # mypy bug? mypy cannot see runtime.get_domain_range type here def get_domain_range(field: common.Field, dim: common.Dimension) -> tuple[int, int]: return (field.domain[dim].unit_range.start, field.domain[dim].unit_range.stop) -@runtime.if_stmt.register(EMBEDDED) +@runtime.if_stmt.register(EMBEDDED) # type: ignore[has-type] # mypy bug? mypy cannot see runtime.if_stmt type here def if_stmt(cond: bool, true_branch: Callable[[], None], false_branch: Callable[[], None]) -> None: """ (Stateful) if statement. @@ -1665,7 +1665,7 @@ def if_stmt(cond: bool, true_branch: Callable[[], None], false_branch: Callable[ false_branch() -@runtime.temporary.register(EMBEDDED) +@runtime.temporary.register(EMBEDDED) # type: ignore[has-type] # mypy bug? 
mypy cannot see runtime.temporary type here def temporary(domain: runtime.CartesianDomain | runtime.UnstructuredDomain, dtype): type_ = runtime._dtypebuiltin_to_ts(dtype) new_domain = common.domain(domain) diff --git a/src/gt4py/next/iterator/runtime.py b/src/gt4py/next/iterator/runtime.py index 88a466229f..c02610472d 100644 --- a/src/gt4py/next/iterator/runtime.py +++ b/src/gt4py/next/iterator/runtime.py @@ -67,7 +67,7 @@ def itir(self, *args): fencil_definition = trace_fencil_definition(self.definition, args) - if config.DEBUG: + if config.debug: devtools.debug(fencil_definition) return fencil_definition diff --git a/src/gt4py/next/otf/binding/nanobind.py b/src/gt4py/next/otf/binding/nanobind.py index 416ac6a7c0..beb7382c83 100644 --- a/src/gt4py/next/otf/binding/nanobind.py +++ b/src/gt4py/next/otf/binding/nanobind.py @@ -209,7 +209,7 @@ def make_argument(name: str, type_: ts.TypeSpec) -> str | BufferSID | Tuple: name=dim.value, static_stride=1 if ( - config.UNSTRUCTURED_HORIZONTAL_HAS_UNIT_STRIDE + config.unstructured_horizontal_has_unit_stride and dim.kind == common.DimensionKind.HORIZONTAL ) else None, diff --git a/src/gt4py/next/otf/compilation/build_systems/cmake.py b/src/gt4py/next/otf/compilation/build_systems/cmake.py index 1b79cad6e4..374f2b88e8 100644 --- a/src/gt4py/next/otf/compilation/build_systems/cmake.py +++ b/src/gt4py/next/otf/compilation/build_systems/cmake.py @@ -13,7 +13,7 @@ import pathlib import subprocess import warnings -from typing import TypeVar +from typing import TYPE_CHECKING, TypeVar from gt4py._core import definitions as core_defs from gt4py.next import config, errors @@ -22,6 +22,10 @@ from gt4py.next.otf.compilation.build_systems import cmake_lists +if TYPE_CHECKING: + from gt4py.next import config_type + + def get_device_arch() -> str | None: if core_defs.CUPY_DEVICE_TYPE == core_defs.DeviceType.CUDA: # use `cp` from core_defs to avoid trying to re-import cupy @@ -69,13 +73,13 @@ class CMakeFactory( """Create a CMakeProject 
from a ``CompilableSource`` stage object with given CMake settings.""" cmake_generator_name: str = "Ninja" - cmake_build_type: config.CMakeBuildType = config.CMakeBuildType.DEBUG + cmake_build_type: config_type.CMakeBuildType = config.CMakeBuildType.DEBUG cmake_extra_flags: list[str] = dataclasses.field(default_factory=list) def __call__( self, source: stages.CompilableProject[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], - cache_lifetime: config.BuildCacheLifetime, + cache_lifetime: config_type.BuildCacheLifetime, ) -> CMakeProject: if not source.binding_source: raise NotImplementedError( @@ -128,7 +132,7 @@ class CMakeProject(stages.BuildSystemProject[CPPLikeCodeSpecT, code_specs.Python source_files: dict[str, str] program_name: str generator_name: str = "Ninja" - build_type: config.CMakeBuildType = config.CMakeBuildType.DEBUG + build_type: config_type.CMakeBuildType = config.CMakeBuildType.DEBUG extra_cmake_flags: list[str] = dataclasses.field(default_factory=list) def build(self) -> None: diff --git a/src/gt4py/next/otf/compilation/build_systems/compiledb.py b/src/gt4py/next/otf/compilation/build_systems/compiledb.py index 347b0e25e9..4c13307c4a 100644 --- a/src/gt4py/next/otf/compilation/build_systems/compiledb.py +++ b/src/gt4py/next/otf/compilation/build_systems/compiledb.py @@ -14,7 +14,7 @@ import re import shutil import subprocess -from typing import Optional, TypeVar +from typing import TYPE_CHECKING, Optional, TypeVar from gt4py._core import locking from gt4py.next import config, errors @@ -24,6 +24,9 @@ from gt4py.next.otf.compilation.build_systems import cmake +if TYPE_CHECKING: + from gt4py.next import config_type + CPPLikeCodeSpecT = TypeVar("CPPLikeCodeSpecT", bound=code_specs.CPPLikeCodeSpec) @@ -39,14 +42,14 @@ class CompiledbFactory( and library dependencies. 
""" - cmake_build_type: config.CMakeBuildType = config.CMakeBuildType.DEBUG + cmake_build_type: config_type.CMakeBuildType = config.CMakeBuildType.DEBUG cmake_extra_flags: list[str] = dataclasses.field(default_factory=list) renew_compiledb: bool = False def __call__( self, source: stages.CompilableProject[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], - cache_lifetime: config.BuildCacheLifetime, + cache_lifetime: config_type.BuildCacheLifetime, ) -> CompiledbProject: if not source.binding_source: raise NotImplementedError( @@ -244,7 +247,7 @@ def _cc_prototype_program_name( def _cc_prototype_program_source( deps: tuple[interface.LibraryDependency, ...], - build_type: config.CMakeBuildType, + build_type: config_type.CMakeBuildType, cmake_flags: list[str], code_spec: code_specs.CPPLikeCodeSpec, ) -> stages.ProgramSource: @@ -260,9 +263,9 @@ def _cc_prototype_program_source( def _cc_get_compiledb( renew_compiledb: bool, prototype_program_source: stages.ProgramSource, - build_type: config.CMakeBuildType, + build_type: config_type.CMakeBuildType, cmake_flags: list[str], - cache_lifetime: config.BuildCacheLifetime, + cache_lifetime: config_type.BuildCacheLifetime, ) -> pathlib.Path: cache_path = cache.get_cache_folder( stages.CompilableProject(prototype_program_source, None), cache_lifetime @@ -293,9 +296,9 @@ def _cc_find_compiledb(path: pathlib.Path) -> Optional[pathlib.Path]: def _cc_create_compiledb( prototype_program_source: stages.ProgramSource, - build_type: config.CMakeBuildType, + build_type: config_type.CMakeBuildType, cmake_flags: list[str], - cache_lifetime: config.BuildCacheLifetime, + cache_lifetime: config_type.BuildCacheLifetime, ) -> pathlib.Path: prototype_project = cmake.CMakeFactory( cmake_generator_name="Ninja", diff --git a/src/gt4py/next/otf/compilation/cache.py b/src/gt4py/next/otf/compilation/cache.py index 6ce2ba6eac..ebbe4b8601 100644 --- a/src/gt4py/next/otf/compilation/cache.py +++ b/src/gt4py/next/otf/compilation/cache.py @@ -8,15 +8,22 @@ 
"""Caching for compiled backend artifacts.""" +from __future__ import annotations + import hashlib import pathlib import tempfile +from typing import TYPE_CHECKING from gt4py.next import config from gt4py.next.otf import stages from gt4py.next.otf.binding import interface +if TYPE_CHECKING: + from gt4py.next import config_type + + _session_cache_dir = tempfile.TemporaryDirectory(prefix="gt4py_session_") _session_cache_dir_path = pathlib.Path(_session_cache_dir.name) @@ -50,7 +57,7 @@ def _cache_folder_name(source: stages.ProgramSource) -> str: def get_cache_folder( - compilable_source: stages.CompilableProject, lifetime: config.BuildCacheLifetime + compilable_source: stages.CompilableProject, lifetime: config_type.BuildCacheLifetime ) -> pathlib.Path: """ Construct the path to where the build system project artifact of a compilable source should be cached. @@ -64,7 +71,7 @@ def get_cache_folder( case config.BuildCacheLifetime.SESSION: base_path = _session_cache_dir_path case config.BuildCacheLifetime.PERSISTENT: - base_path = config.BUILD_CACHE_DIR + base_path = config.build_cache_dir case _: raise ValueError("Unsupported caching lifetime.") diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/otf/compilation/compiler.py index 3748d95192..5877f8d06f 100644 --- a/src/gt4py/next/otf/compilation/compiler.py +++ b/src/gt4py/next/otf/compilation/compiler.py @@ -10,16 +10,19 @@ import dataclasses import pathlib -from typing import Protocol, TypeVar +from typing import TYPE_CHECKING, Protocol, TypeVar import factory from gt4py._core import locking -from gt4py.next import config from gt4py.next.otf import code_specs, definitions, stages, workflow from gt4py.next.otf.compilation import build_data, cache, importer +if TYPE_CHECKING: + from gt4py.next import config_type + + T = TypeVar("T") @@ -40,7 +43,7 @@ class BuildSystemProjectGenerator(Protocol[CodeSpecT, TargetCodeSpecT]): def __call__( self, source: stages.CompilableProject[CodeSpecT, 
TargetCodeSpecT], - cache_lifetime: config.BuildCacheLifetime, + cache_lifetime: config_type.BuildCacheLifetime, ) -> stages.BuildSystemProject[CodeSpecT, TargetCodeSpecT]: ... @@ -58,7 +61,7 @@ class Compiler( ): """Use any build system (via configured factory) to compile a GT4Py program to a ``gt4py.next.otf.stages.CompiledProgram``.""" - cache_lifetime: config.BuildCacheLifetime + cache_lifetime: config_type.BuildCacheLifetime builder_factory: BuildSystemProjectGenerator[CPPLikeCodeSpecT, code_specs.PythonCodeSpec] force_recompile: bool = False diff --git a/src/gt4py/next/otf/compiled_program.py b/src/gt4py/next/otf/compiled_program.py index 64cdef3fd8..0a4bb731c9 100644 --- a/src/gt4py/next/otf/compiled_program.py +++ b/src/gt4py/next/otf/compiled_program.py @@ -149,9 +149,9 @@ def compiled_program_call_context( def _init_async_compilation_pool() -> None: global _async_compilation_pool - if _async_compilation_pool is None and config.BUILD_JOBS > 0: + if _async_compilation_pool is None and config.build_jobs > 0: _async_compilation_pool = concurrent.futures.ThreadPoolExecutor( - max_workers=config.BUILD_JOBS + max_workers=config.build_jobs ) diff --git a/src/gt4py/next/otf/options.py b/src/gt4py/next/otf/options.py index 4f77d44586..48672bb36b 100644 --- a/src/gt4py/next/otf/options.py +++ b/src/gt4py/next/otf/options.py @@ -25,7 +25,7 @@ class CompilationOptions: #: to `compile` before calling. # Uses a factory to make changes to the config after module import time take effect. This is # mostly important for testing. Users should not rely on it. - enable_jit: bool = dataclasses.field(default_factory=lambda: config.ENABLE_JIT_DEFAULT) + enable_jit: bool = dataclasses.field(default_factory=lambda: config.enable_jit_default) #: If the user requests static params, they will be used later to initialize CompiledPrograms. #: By default the set of static params is set when compiling for the first time, e.g. 
on call diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py index eb84de9185..cabd4eb657 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py @@ -709,7 +709,7 @@ def add_nested_sdfg( if dataname in data_args: # Uninitialized arguments should not be used inside the nested SDFG. if (arg_node := data_args[dataname]) is None: - inner_ctx.sdfg.remove_data(dataname, validate=gtx_config.DEBUG) + inner_ctx.sdfg.remove_data(dataname, validate=gtx_config.debug) else: input_memlets[dataname] = outer_ctx.sdfg.make_array_memlet( arg_node.dc_node.data diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/concat_where_mapper.py b/src/gt4py/next/program_processors/runners/dace/transformations/concat_where_mapper.py index 8052426f33..c6266d806e 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/concat_where_mapper.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/concat_where_mapper.py @@ -529,7 +529,7 @@ def _process_descending_points_of_state( # nested SDFG and also delete its alias inside it. 
_cleanup_memlet_path(state, descending_point) descending_point.consumer.remove_in_connector(descending_point.edge.dst_conn) - nsdfg.sdfg.remove_data(descending_point.edge.dst_conn, validate=gtx_config.DEBUG) + nsdfg.sdfg.remove_data(descending_point.edge.dst_conn, validate=gtx_config.debug) return nb_applies diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/backend.py b/src/gt4py/next/program_processors/runners/dace/workflow/backend.py index ac3a8b757a..0be84b53ea 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/backend.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/backend.py @@ -116,7 +116,7 @@ def make_dace_backend( # Set `unit_strides_kind` based on the gt4py env configuration. optimization_args = optimization_args | { "unit_strides_kind": common.DimensionKind.HORIZONTAL - if config.UNSTRUCTURED_HORIZONTAL_HAS_UNIT_STRIDE + if config.unstructured_horizontal_has_unit_stride else None } diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/common.py b/src/gt4py/next/program_processors/runners/dace/workflow/common.py index cfb0d23596..5e193db3a6 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/common.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/common.py @@ -6,9 +6,11 @@ # Please, refer to the LICENSE file in the root directory. 
# SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + import contextlib import os -from typing import Any, Final, Generator, Optional +from typing import TYPE_CHECKING, Any, Final, Generator, Optional import dace @@ -16,6 +18,10 @@ from gt4py.next import config +if TYPE_CHECKING: + from gt4py.next import config_type + + SDFG_ARG_METRIC_LEVEL: Final[str] = "gt_metrics_level" """Name of SDFG argument to input the GT4Py metrics level.""" @@ -26,7 +32,7 @@ def set_dace_config( device_type: core_defs.DeviceType, - cmake_build_type: Optional[config.CMakeBuildType] = None, + cmake_build_type: Optional[config_type.CMakeBuildType] = None, ) -> None: """Set the DaCe configuration as required by GT4Py. diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py index e1747b7ac3..c9b015044f 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py @@ -12,7 +12,7 @@ import os import warnings from collections.abc import Callable, MutableSequence, Sequence -from typing import Any +from typing import TYPE_CHECKING, Any import dace import factory @@ -24,6 +24,10 @@ from gt4py.next.program_processors.runners.dace.workflow import common as gtx_wfdcommon +if TYPE_CHECKING: + from gt4py.next import config_type + + class CompiledDaceProgram: sdfg_program: dace.CompiledSDFG @@ -96,7 +100,7 @@ def fast_call(self) -> None: "Argument vector was not set properly." 
) self.sdfg_program.fast_call( - self.csdfg_argv, self.csdfg_init_argv, do_gpu_check=config.DEBUG + self.csdfg_argv, self.csdfg_init_argv, do_gpu_check=config.debug ) def __call__(self, **kwargs: Any) -> None: @@ -129,9 +133,9 @@ class DaCeCompiler( """Use the dace build system to compile a GT4Py program to a ``gt4py.next.otf.stages.CompiledProgram``.""" bind_func_name: str - cache_lifetime: config.BuildCacheLifetime + cache_lifetime: config_type.BuildCacheLifetime device_type: core_defs.DeviceType - cmake_build_type: config.CMakeBuildType = config.CMakeBuildType.DEBUG + cmake_build_type: config_type.CMakeBuildType = config.CMakeBuildType.DEBUG def __call__( self, diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py index 103e7af33b..ad96ddd1fa 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py @@ -64,7 +64,7 @@ def decorated_program( filter_args=False, ) this_call_args |= { - gtx_wfdcommon.SDFG_ARG_METRIC_LEVEL: config.COLLECT_METRICS_LEVEL, + gtx_wfdcommon.SDFG_ARG_METRIC_LEVEL: config.collect_metrics_level, gtx_wfdcommon.SDFG_ARG_METRIC_COMPUTE_TIME: collect_time_arg, } fun.construct_arguments(**this_call_args) diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py index 62febd0965..5fbcb24676 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py @@ -9,7 +9,7 @@ from __future__ import annotations import functools -from typing import Final +from typing import TYPE_CHECKING, Final import factory @@ -28,6 +28,9 @@ ) +if TYPE_CHECKING: + from gt4py.next import config_type + _GT_DACE_BINDING_FUNCTION_NAME: Final[str] = "update_sdfg_args" @@ -38,8 +41,8 @@ class Meta: 
class Params: auto_optimize: bool = False device_type: core_defs.DeviceType = core_defs.DeviceType.CPU - cmake_build_type: config.CMakeBuildType = factory.LazyFunction( # type: ignore[assignment] # factory-boy typing not precise enough - lambda: config.CMAKE_BUILD_TYPE + cmake_build_type: config_type.CMakeBuildType = factory.LazyFunction( # type: ignore[assignment] # factoryboy's type stubs seem incomplete + lambda: config.cmake_build_type ) cached_translation = factory.Trait( @@ -47,7 +50,7 @@ class Params: lambda o: workflow.CachedStep( o.bare_translation, hash_function=stages.fingerprint_compilable_program, - cache=filecache.FileCache(str(config.BUILD_CACHE_DIR / "translation_cache")), + cache=filecache.FileCache(str(config.build_cache_dir / "translation_cache")), ) ), ) @@ -68,7 +71,7 @@ class Params: compilation = factory.SubFactory( DaCeCompilationStepFactory, bind_func_name=_GT_DACE_BINDING_FUNCTION_NAME, - cache_lifetime=factory.LazyFunction(lambda: config.BUILD_CACHE_LIFETIME), + cache_lifetime=factory.LazyFunction(lambda: config.build_cache_lifetime), device_type=factory.SelfAttribute("..device_type"), cmake_build_type=factory.SelfAttribute("..cmake_build_type"), ) diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py index 7df853760e..fe8536c922 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py @@ -38,7 +38,7 @@ def find_constant_symbols( """Helper function to find symbols to replace with constant values.""" constant_symbols: dict[str, int] = {} - if config.UNSTRUCTURED_HORIZONTAL_HAS_UNIT_STRIDE: + if config.unstructured_horizontal_has_unit_stride: # Search the stride symbols corresponding to the horizontal dimension for p in ir.params: if isinstance(p.type, ts.FieldType): @@ -267,7 +267,7 @@ def add_instrumentation(sdfg: dace.SDFG, gpu: bool) 
-> None: dace.Memlet(f"{output}[0]"), ) - if gpu and _has_gpu_schedule(sdfg) and config.ADD_GPU_TRACE_MARKERS: + if gpu and _has_gpu_schedule(sdfg) and config.add_gpu_trace_markers: sdfg.instrument = dace.dtypes.InstrumentationType.GPU_TX_MARKERS for node, _ in sdfg.all_nodes_recursive(): if isinstance( diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index c1743dea6a..00bef6d9f8 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -6,8 +6,10 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + import functools -from typing import Any +from typing import TYPE_CHECKING, Any import factory import numpy as np @@ -25,6 +27,10 @@ from gt4py.next.program_processors.codegens.gtfn import gtfn_module +if TYPE_CHECKING: + from gt4py.next import config_type + + def convert_arg(arg: Any) -> Any: # Note: this function is on the hot path and needs to have minimal overhead. 
if (origin := getattr(arg, "__gt_origin__", None)) is not None: @@ -112,8 +118,8 @@ class Meta: class Params: device_type: core_defs.DeviceType = core_defs.DeviceType.CPU - cmake_build_type: config.CMakeBuildType = factory.LazyFunction( # type: ignore[assignment] # factory-boy typing not precise enough - lambda: config.CMAKE_BUILD_TYPE + cmake_build_type: config_type.CMakeBuildType = factory.LazyFunction( # type: ignore[assignment] # factory-boy typing not precise enough + lambda: config.cmake_build_type ) builder_factory: compiler.BuildSystemProjectGenerator = factory.LazyAttribute( # type: ignore[assignment] # factory-boy typing not precise enough lambda o: compiledb.CompiledbFactory(cmake_build_type=o.cmake_build_type) @@ -124,7 +130,7 @@ class Params: lambda o: workflow.CachedStep( o.bare_translation, hash_function=stages.fingerprint_compilable_program, - cache=filecache.FileCache(str(config.BUILD_CACHE_DIR / "gtfn_cache")), + cache=filecache.FileCache(str(config.build_cache_dir / "gtfn_cache")), ) ), ) @@ -137,11 +143,11 @@ class Params: translation = factory.LazyAttribute(lambda o: o.bare_translation) bindings: workflow.Workflow[stages.ProgramSource, stages.CompilableProject] = ( - nanobind.bind_source + nanobind.bind_source # type: ignore[has-type] # mypy bug? 
mypy cannot see nanobind.bind_source type here ) compilation = factory.SubFactory( compiler.CompilerFactory, - cache_lifetime=factory.LazyFunction(lambda: config.BUILD_CACHE_LIFETIME), + cache_lifetime=factory.LazyFunction(lambda: config.build_cache_lifetime), builder_factory=factory.SelfAttribute("..builder_factory"), ) decoration = factory.LazyAttribute( diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 5ee0a67f25..9280a37922 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -220,7 +220,7 @@ class Roundtrip(workflow.Workflow[definitions.CompilableProgramDef, stages.Execu transforms: itir_transforms.GTIRTransform = itir_transforms.apply_common_transforms # type: ignore[assignment] # TODO(havogt): cleanup interface of `apply_common_transforms` def __call__(self, inp: definitions.CompilableProgramDef) -> stages.ExecutableProgram: - debug = config.DEBUG if self.debug is None else self.debug + debug = config.debug if self.debug is None else self.debug fencil = fencil_generator( inp.data, diff --git a/tests/eve_tests/unit_tests/test_utils.py b/tests/eve_tests/unit_tests/test_utils.py index ae8e938396..191afd3fb2 100644 --- a/tests/eve_tests/unit_tests/test_utils.py +++ b/tests/eve_tests/unit_tests/test_utils.py @@ -285,6 +285,172 @@ def func(a, b, c): assert fp3() == 6 +class TestTypeMapping: + """Unit tests for TypeMapping class.""" + + def test_basic_getitem(self): + """Test basic type-to-value mapping retrieval.""" + from gt4py.eve.utils import TypeMapping + + def fallback(type_): + return f"default_{type_.__name__}" + + mapping = TypeMapping(fallback) + + # Register some types + mapping[int] = "integer" + mapping[str] = "string" + + assert mapping[int] == "integer" + assert mapping[str] == "string" + + def test_fallback_factory(self): + """Test that fallback factory is used for unregistered types.""" + 
from gt4py.eve.utils import TypeMapping + + def fallback(type_): + return f"default_{type_.__name__}" + + mapping = TypeMapping(fallback) + mapping[int] = "integer" + + # Unregistered type should use fallback + assert mapping[str] == "default_str" + assert mapping[float] == "default_float" + assert mapping[list] == "default_list" + + def test_setitem_and_register(self): + """Test both __setitem__ and register methods.""" + from gt4py.eve.utils import TypeMapping + + mapping = TypeMapping(lambda t: None) + + # Using __setitem__ + mapping[int] = "via_setitem" + assert mapping[int] == "via_setitem" + + # Using register method + result = mapping.register(str, "via_register") + assert result == "via_register" + assert mapping[str] == "via_register" + + # Using register method in a decorator style + result = mapping.register(str)("via_register_decorator") + assert result == "via_register_decorator" + assert mapping[str] == "via_register_decorator" + + def test_callable_value_registration(self): + """Test registering callable objects as values.""" + from gt4py.eve.utils import TypeMapping + + def int_handler(a): + return f"int_handler: {a}" + + def str_handler(a): + return f"str_handler: {a}" + + def any_handler(a): + return f"{type(a).__name__}_handler: {a}" + + mapping = TypeMapping(lambda t: any_handler) + mapping[int] = int_handler + mapping[str] = str_handler + + assert callable(mapping[int]) + assert callable(mapping[str]) + assert mapping[int](1) == "int_handler: 1" + assert mapping[str](2) == "str_handler: 2" + assert mapping[tuple]((3, 4)) == "tuple_handler: (3, 4)" + + def test_multiple_type_registrations(self): + """Test registering and retrieving multiple types.""" + from gt4py.eve.utils import TypeMapping + + mapping = TypeMapping(lambda t: f"fallback_{t.__name__}") + + types_values = { + int: "number", + str: "text", + float: "decimal", + list: "sequence", + dict: "mapping", + set: "unique", + tuple: "immutable", + } + + for type_, value in 
types_values.items(): + mapping[type_] = value + + for type_, expected_value in types_values.items(): + assert mapping[type_] == expected_value + + def test_overwrite_registration(self): + """Test that re-registering a type overwrites the previous value.""" + from gt4py.eve.utils import TypeMapping + + mapping = TypeMapping(lambda t: None) + + mapping[int] = "first" + assert mapping[int] == "first" + + mapping[int] = "second" + assert mapping[int] == "second" + + mapping[int] = "third" + assert mapping[int] == "third" + + def test_subclass_dispatch(self): + """Test that singledispatch works with subclasses.""" + from gt4py.eve.utils import TypeMapping + + mapping = TypeMapping(lambda t: "default") + + class BaseClass: + pass + + class SubClass(BaseClass): + pass + + mapping[BaseClass] = "base" + + # Subclass should dispatch to BaseClass handler + assert mapping[SubClass] == "base" + assert mapping[BaseClass] == "base" + + def test_complex_fallback_factory(self): + """Test TypeMapping with a complex fallback factory function.""" + from gt4py.eve.utils import TypeMapping + + def complex_fallback(type_): + if hasattr(type_, "__len__"): + return f"sized_{type_.__name__}" + else: + return f"unsized_{type_.__name__}" + + mapping = TypeMapping(complex_fallback) + + assert "sized" in mapping[str] + assert "sized" in mapping[list] + assert "sized" in mapping[dict] + assert "unsized" in mapping[float] + assert "unsized" in mapping[int] + + def test_iteration(self): + """Test iteration over registered types.""" + from gt4py.eve.utils import TypeMapping + + mapping = TypeMapping(lambda t: None) + types_to_register = [int, str, float, list, dict] + + for i, type_ in enumerate(types_to_register): + mapping[type_] = f"value_{i}" + + # Check that all registered types are in iteration + registered_types = list(mapping) + for type_ in types_to_register: + assert type_ in registered_types + + def test_noninstantiable_class(): @eve.utils.noninstantiable class 
NonInstantiableClass(eve.datamodels.DataModel): diff --git a/tests/next_tests/__init__.py b/tests/next_tests/__init__.py index ba174aa6d1..b461b06153 100644 --- a/tests/next_tests/__init__.py +++ b/tests/next_tests/__init__.py @@ -16,9 +16,10 @@ __all__ = ["definitions", "get_processor_id"] -if config.BUILD_CACHE_LIFETIME is config.BuildCacheLifetime.PERSISTENT: +if config.build_cache_lifetime is config.BuildCacheLifetime.PERSISTENT: warnings.warn( - "You are running GT4Py tests with BUILD_CACHE_LIFETIME set to PERSISTENT!", UserWarning + "You are running GT4Py tests with 'config.build_cache_lifetime' set to PERSISTENT!", + UserWarning, ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py index 4578576f02..2db15da661 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py @@ -60,7 +60,7 @@ def testee(a: cases.IField, out: cases.IField): testee_op(a, a, out=out) with ( - mock.patch("gt4py.next.config.COLLECT_METRICS_LEVEL", metrics_level), + mock.patch("gt4py.next.config.collect_metrics_level", metrics_level), mock.patch( "gt4py.next.instrumentation.metrics.sources", collections.defaultdict(metrics.Source) ), diff --git a/tests/next_tests/unit_tests/instrumentation_tests/test_metrics.py b/tests/next_tests/unit_tests/instrumentation_tests/test_metrics.py index 865d69b4e4..2dc7836034 100644 --- a/tests/next_tests/unit_tests/instrumentation_tests/test_metrics.py +++ b/tests/next_tests/unit_tests/instrumentation_tests/test_metrics.py @@ -59,7 +59,7 @@ def test_set_current_source_key_different_key_raises(self): class TestSourceKeyContextManager: def test_context_manager_sets_and_resets_key(self): - with unittest.mock.patch("gt4py.next.config.COLLECT_METRICS_LEVEL", metrics.MINIMAL): + with 
gt_config.overrides(collect_metrics_level=metrics.MINIMAL): metrics._source_key_cvar.set( metrics._NO_KEY_SET_MARKER_ ) # Reset context variable before test @@ -79,7 +79,7 @@ def test_context_manager_sets_and_resets_key(self): ) def test_context_manager_with_no_key(self): - with unittest.mock.patch("gt4py.next.config.COLLECT_METRICS_LEVEL", metrics.MINIMAL): + with gt_config.overrides(collect_metrics_level=metrics.MINIMAL): metrics._source_key_cvar.set("__BEFORE__MARKER__") # Reset context variable before test with metrics.SourceKeyContextManager(): @@ -93,7 +93,7 @@ def test_context_manager_with_no_key(self): assert metrics._source_key_cvar.get(metrics._NO_KEY_SET_MARKER_) == "__BEFORE__MARKER__" def test_context_manager_nested(self): - with unittest.mock.patch("gt4py.next.config.COLLECT_METRICS_LEVEL", metrics.MINIMAL): + with gt_config.overrides(collect_metrics_level=metrics.MINIMAL): metrics._source_key_cvar.set(metrics._NO_KEY_SET_MARKER_) key1 = "outer_key" key2 = "inner_key" @@ -122,7 +122,7 @@ class TestCollector( ): ... 
metrics._source_key_cvar.set(metrics._NO_KEY_SET_MARKER_) - with unittest.mock.patch("gt4py.next.config.COLLECT_METRICS_LEVEL", metrics.MINIMAL): + with gt_config.overrides(collect_metrics_level=metrics.MINIMAL): outer_key = "outer_key" metrics.set_current_source_key("outer_key") assert metrics.get_current_source_key() == outer_key @@ -141,7 +141,7 @@ class TestCollector( key = "test_disabled" metrics._source_key_cvar.set(metrics._NO_KEY_SET_MARKER_) - with unittest.mock.patch("gt4py.next.config.COLLECT_METRICS_LEVEL", metrics.DISABLED): + with gt_config.overrides(collect_metrics_level=metrics.DISABLED): metrics.set_current_source_key(key) with TestCollector(key=key): @@ -162,7 +162,7 @@ class CustomCollector( key = "test_custom" metrics._source_key_cvar.set(metrics._NO_KEY_SET_MARKER_) - with unittest.mock.patch("gt4py.next.config.COLLECT_METRICS_LEVEL", metrics.PERFORMANCE): + with gt_config.overrides(collect_metrics_level=metrics.PERFORMANCE): with CustomCollector(key=key): pass @@ -430,7 +430,7 @@ def test_dump_json(sample_source_metrics: Mapping[str, metrics.Source], tmp_path class TestDumpMetricsAtExit: - @pytest.mark.parametrize("mode", ["explicit", "auto", None]) + @pytest.mark.parametrize("mode", ["explicit", "auto", False]) def test_dump_metrics_at_exit_enabled( self, sample_source_metrics: Mapping[str, metrics.Source], @@ -438,29 +438,29 @@ def test_dump_metrics_at_exit_enabled( mode: str | None, ): """Test _dump_metrics_at_exit writes to a file when enabled.""" - explicit_output_filename = str(tmp_path / "explicit_metrics.json") - auto_output_filename = str(tmp_path / gt_config._init_dump_metrics_filename()) + explicit_output_filename = tmp_path / "explicit_metrics.json" + auto_output_filename = tmp_path / metrics._init_dump_metrics_filename() if mode == "explicit": output_filename = explicit_output_filename elif mode == "auto": output_filename = auto_output_filename else: - output_filename = None + output_filename = False - with 
unittest.mock.patch("gt4py.next.config.DUMP_METRICS_AT_EXIT", output_filename): + with gt_config.overrides(dump_metrics_at_exit=output_filename): with unittest.mock.patch( "gt4py.next.instrumentation.metrics.sources", sample_source_metrics ): metrics._dump_metrics_at_exit() - assert (output_filename is None) == (mode is None) + assert (output_filename is False) == (mode is False) if output_filename: - assert pathlib.Path(output_filename).exists() - data = json.loads(pathlib.Path(output_filename).read_text()) + assert output_filename.exists() + data = json.loads(output_filename.read_text()) assert "program1" in data assert "program2" in data - pathlib.Path(output_filename).unlink() # Clean up after test + output_filename.unlink() # Clean up after test else: - assert not pathlib.Path(explicit_output_filename).exists() - assert not pathlib.Path(auto_output_filename).exists() + assert not explicit_output_filename.exists() + assert not auto_output_filename.exists() diff --git a/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/test_compiledb.py b/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/test_compiledb.py index 7ff3525cf8..738e2e2972 100644 --- a/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/test_compiledb.py +++ b/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/test_compiledb.py @@ -45,7 +45,7 @@ def test_compiledb_project_is_relocatable(compilable_source_example, clean_examp builder.build() - with tempfile.TemporaryDirectory(dir=config.BUILD_CACHE_DIR) as tmpdir: + with tempfile.TemporaryDirectory(dir=config.build_cache_dir) as tmpdir: # copy the project to a new location relocated_dir = pathlib.Path(tmpdir) / "relocated" shutil.copytree( diff --git a/tests/next_tests/unit_tests/otf_tests/test_languages.py b/tests/next_tests/unit_tests/otf_tests/test_code_specs.py similarity index 100% rename from tests/next_tests/unit_tests/otf_tests/test_languages.py 
rename to tests/next_tests/unit_tests/otf_tests/test_code_specs.py diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_backend.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_backend.py index cd90208340..7e4419f03d 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_backend.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_backend.py @@ -16,6 +16,7 @@ from gt4py import next as gtx from gt4py._core import definitions as core_defs +from gt4py.next import config as gt_config from gt4py.next.program_processors.runners.dace.workflow import ( backend as dace_wf_backend, ) @@ -92,7 +93,7 @@ def mocked_gpu_transformation(*args, **kwargs) -> dace.SDFG: monkeypatch.setattr(gtx_transformations, "gt_auto_optimize", mocked_auto_optimize) monkeypatch.setattr(gtx_transformations, "gt_gpu_transformation", mocked_gpu_transformation) - with mock.patch("gt4py.next.config.UNSTRUCTURED_HORIZONTAL_HAS_UNIT_STRIDE", on_gpu): + with gt_config.overrides(unstructured_horizontal_has_unit_stride=on_gpu): custom_backend = dace_wf_backend.make_dace_backend( gpu=on_gpu, cached=False, diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_translation.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_translation.py index 88eaffe345..58cc3c8dcb 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_translation.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_translation.py @@ -25,6 +25,7 @@ common as dace_wf_common, ) from gt4py.next.type_system import type_specifications as ts +from gt4py.next import config as gt_config from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( V2E, @@ -97,7 +98,7 @@ def 
test_find_constant_symbols(has_unit_stride, disable_field_origin): ], ) - with mock.patch("gt4py.next.config.UNSTRUCTURED_HORIZONTAL_HAS_UNIT_STRIDE", has_unit_stride): + with gt_config.overrides(unstructured_horizontal_has_unit_stride=has_unit_stride): sdfg = _translate_gtir_to_sdfg( ir=ir, offset_provider=SKIP_VALUE_MESH.offset_provider, diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py index 96d8c6e27c..36dc23eefc 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py @@ -55,12 +55,14 @@ def test_backend_factory_trait_cached(): def test_backend_factory_build_cache_config(monkeypatch): - monkeypatch.setattr(config, "BUILD_CACHE_LIFETIME", config.BuildCacheLifetime.SESSION) - session_version = gtfn.GTFNBackendFactory() - monkeypatch.setattr(config, "BUILD_CACHE_LIFETIME", config.BuildCacheLifetime.PERSISTENT) - persistent_version = gtfn.GTFNBackendFactory() + with config.overrides(build_cache_lifetime=config.BuildCacheLifetime.SESSION): + session_version = gtfn.GTFNBackendFactory() assert session_version.executor.compilation.cache_lifetime is config.BuildCacheLifetime.SESSION + + with config.overrides(build_cache_lifetime=config.BuildCacheLifetime.PERSISTENT): + persistent_version = gtfn.GTFNBackendFactory() + assert ( persistent_version.executor.compilation.cache_lifetime is config.BuildCacheLifetime.PERSISTENT @@ -68,15 +70,17 @@ def test_backend_factory_build_cache_config(monkeypatch): def test_backend_factory_build_type_config(monkeypatch): - monkeypatch.setattr(config, "CMAKE_BUILD_TYPE", config.CMakeBuildType.RELEASE) - release_version = gtfn.GTFNBackendFactory() - monkeypatch.setattr(config, "CMAKE_BUILD_TYPE", config.CMakeBuildType.MIN_SIZE_REL) - min_size_version = gtfn.GTFNBackendFactory() + with 
config.overrides(cmake_build_type=config.CMakeBuildType.RELEASE): + release_version = gtfn.GTFNBackendFactory() assert ( release_version.executor.compilation.builder_factory.cmake_build_type is config.CMakeBuildType.RELEASE ) + + with config.overrides(cmake_build_type=config.CMakeBuildType.MIN_SIZE_REL): + min_size_version = gtfn.GTFNBackendFactory() + assert ( min_size_version.executor.compilation.builder_factory.cmake_build_type is config.CMakeBuildType.MIN_SIZE_REL diff --git a/tests/next_tests/unit_tests/test_config.py b/tests/next_tests/unit_tests/test_config.py index a33bd5734a..ebbb33eb48 100644 --- a/tests/next_tests/unit_tests/test_config.py +++ b/tests/next_tests/unit_tests/test_config.py @@ -6,43 +6,431 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + +import enum import os +import pathlib +from typing import Any +from unittest import mock import pytest -from gt4py.next import config +from gt4py.next._config import Config, ConfigManager, OptionDescriptor, UpdateScope + + +class TestOptionDescriptorBasics: + """Test basic OptionDescriptor functionality.""" + + def test_descriptor_attribute_access(self) -> None: + """Test attribute-style access to configuration options.""" + + class TestConfig(ConfigManager): + debug = OptionDescriptor(option_type=bool, default=False) + + cfg = TestConfig() + assert cfg.debug is False + + def test_descriptor_with_default_value(self) -> None: + """Test that descriptor stores and returns default values.""" + + class TestConfig(ConfigManager): + name = OptionDescriptor(option_type=str, default="test") + + cfg = TestConfig() + assert cfg.name == "test" + + def test_descriptor_with_default_factory(self) -> None: + """Test that descriptor uses default_factory to compute defaults.""" + + class TestConfig(ConfigManager): + base = OptionDescriptor(option_type=int, default=10) + derived = OptionDescriptor( + option_type=int, 
default_factory=lambda cfg: cfg.get("base") * 2 + ) + + cfg = TestConfig() + assert cfg.derived == 20 + + +class TestStringValueParsing: + """Test environment variable parsing and configuration.""" + + @pytest.mark.parametrize( + "value,expected", + [ + ("False", False), + ("false", False), + ("0", False), + ("off", False), + ("True", True), + ("true", True), + ("1", True), + ("on", True), + ], + ) + def test_parse_bool(self, value, expected) -> None: + """Test parsing boolean environment variables.""" + with mock.patch.dict(os.environ, {"GT4PY_TESTING_OPT": value}): + + class TestConfig(ConfigManager): + testing_opt = OptionDescriptor(option_type=bool, default=False) + + cfg = TestConfig() + assert cfg.testing_opt is expected + + @pytest.mark.parametrize( + "value,expected", + [ + ("42", 42), + ("-5", -5), + ("0", 0), + ], + ) + def test_parse_int(self, value, expected) -> None: + """Test parsing integer environment variables.""" + with mock.patch.dict(os.environ, {"GT4PY_TESTING_OPT": value}): + + class TestConfig(ConfigManager): + testing_opt = OptionDescriptor(option_type=int, default=0) + + cfg = TestConfig() + assert cfg.testing_opt == expected + + @pytest.mark.parametrize( + "value,expected", + [ + ("/tmp/test", pathlib.Path("/tmp/test")), + ("./relative/path", pathlib.Path("./relative/path")), + ("~/user/path", pathlib.Path(os.path.expanduser("~/user/path"))), + ], + ) + def test_parse_path(self, value, expected) -> None: + """Test parsing pathlib.Path environment variables.""" + with mock.patch.dict(os.environ, {"GT4PY_TESTING_OPT": value}): + + class TestConfig(ConfigManager): + testing_opt = OptionDescriptor(option_type=pathlib.Path, default=pathlib.Path("/")) + + cfg = TestConfig() + assert cfg.testing_opt == expected + + def test_parse_enum(self) -> None: + """Test parsing enum options from environment variables.""" + + class Mode(enum.Enum): + DEBUG = "debug" + RELEASE = "release" + + with mock.patch.dict(os.environ, {"GT4PY_TESTING_OPT": "DEBUG"}): + 
+ class TestConfig(ConfigManager): + testing_opt = OptionDescriptor(option_type=Mode, default=Mode.RELEASE) + + cfg = TestConfig() + assert cfg.testing_opt == Mode.DEBUG + + def test_custom_parser(self) -> None: + """Test custom parser for environment variables.""" + + def parse_list(s: str) -> list[str]: + return s.split(",") + + with mock.patch.dict(os.environ, {"GT4PY_ITEMS": "a,b,c"}): + + class TestConfig(ConfigManager): + items = OptionDescriptor(option_type=list, default=[], env_var_parser=parse_list) + + cfg = TestConfig() + assert cfg.items == ["a", "b", "c"] + + def test_invalid_environment_variable_raises_error(self) -> None: + """Test that invalid environment variables raise RuntimeError.""" + with mock.patch.dict(os.environ, {"GT4PY_NUM": "not_a_number"}): + with pytest.raises(RuntimeError, match="Parsing"): + + class TestConfig(ConfigManager): + num = OptionDescriptor(option_type=int, default=0) + + TestConfig() + + +class TestConfigManagerBasics: + """Test ConfigManager basic functionality.""" + + def test_set_changes_global_value(self) -> None: + """Test that set() changes the global configuration value.""" + + class TestConfig(ConfigManager): + value = OptionDescriptor(option_type=int, default=10) + + cfg = TestConfig() + cfg.set("value", 20) + assert cfg.value == 20 + + def test_set_via_attribute_assignment(self) -> None: + """Test that setting via attribute assignment works.""" + + class TestConfig(ConfigManager): + debug = OptionDescriptor(option_type=bool, default=False) + + cfg = TestConfig() + cfg.debug = True + assert cfg.debug is True + + def test_get_rejects_unrecognized_option(self) -> None: + """Test that get() raises AttributeError for unknown options.""" + + class TestConfig(ConfigManager): + opt = OptionDescriptor(option_type=bool, default=False) + + cfg = TestConfig() + with pytest.raises(AttributeError, match="Unrecognized config option"): + cfg.get("nonexistent") + + def test_set_rejects_unrecognized_option(self) -> None: + """Test 
that set() raises AttributeError for unknown options.""" + + class TestConfig(ConfigManager): + opt = OptionDescriptor(option_type=bool, default=False) + + cfg = TestConfig() + with pytest.raises(AttributeError, match="Unrecognized config option"): + cfg.set("nonexistent", True) + + def test_set_blocked_during_context_override(self) -> None: + """Test that set() is blocked while option is overridden in context.""" + + class TestConfig(ConfigManager): + opt = OptionDescriptor(option_type=int, default=10) + + cfg = TestConfig() + with cfg.overrides(opt=20): + with pytest.raises(AttributeError, match="overridden in a context manager"): + cfg.set("opt", 30) + + def test_as_dict_returns_all_options(self) -> None: + """Test that as_dict() returns all configuration options.""" + + class TestConfig(ConfigManager): + opt1 = OptionDescriptor(option_type=int, default=1) + opt2 = OptionDescriptor(option_type=str, default="test") + + cfg = TestConfig() + config_dict = cfg.as_dict() + assert config_dict["opt1"] == 1 + assert config_dict["opt2"] == "test" + + def test_as_dict_reflects_context_overrides(self) -> None: + """Test that as_dict() reflects active context overrides.""" + + class TestConfig(ConfigManager): + value = OptionDescriptor(option_type=int, default=10) + + cfg = TestConfig() + with cfg.overrides(value=99): + assert cfg.as_dict()["value"] == 99 + + +class TestConfigurationPrecedence: + """Test configuration value precedence rules.""" + + def test_environment_variable_overrides_default(self) -> None: + """Test that environment variables override descriptor defaults.""" + with mock.patch.dict(os.environ, {"GT4PY_VALUE": "999"}): + + class TestConfig(ConfigManager): + value = OptionDescriptor(option_type=int, default=100) + + cfg = TestConfig() + assert cfg.value == 999 + + def test_context_override_precedence_chain(self) -> None: + """Test complete precedence: context > global > environment > default.""" + with mock.patch.dict(os.environ, {"GT4PY_NUM": "50"}): + + 
class TestConfig(ConfigManager): + num = OptionDescriptor(option_type=int, default=10) + + cfg = TestConfig() + assert cfg.num == 50 # Environment overrides default + + cfg.set("num", 100) + assert cfg.num == 100 # Global overrides environment + + with cfg.overrides(num=200): + assert cfg.num == 200 # Context overrides global + + assert cfg.num == 100 # Back to global after context + + def test_multiple_option_override(self) -> None: + """Test overriding multiple options simultaneously.""" + + class TestConfig(ConfigManager): + opt1 = OptionDescriptor(option_type=int, default=1) + opt2 = OptionDescriptor(option_type=str, default="a") + opt3 = OptionDescriptor(option_type=bool, default=False) + + cfg = TestConfig() + with cfg.overrides(opt1=10, opt2="b", opt3=True): + assert cfg.opt1 == 10 + assert cfg.opt2 == "b" + assert cfg.opt3 is True + + def test_nested_context_overrides(self) -> None: + """Test nested context overrides.""" + + class TestConfig(ConfigManager): + value = OptionDescriptor(option_type=int, default=1) + + cfg = TestConfig() + with cfg.overrides(value=10): + assert cfg.value == 10 + with cfg.overrides(value=20): + assert cfg.value == 20 + assert cfg.value == 10 + assert cfg.value == 1 + + def test_override_rejects_unrecognized_options(self) -> None: + """Test that overrides reject unknown option names.""" + + class TestConfig(ConfigManager): + opt = OptionDescriptor(option_type=bool, default=False) + + cfg = TestConfig() + with pytest.raises(AttributeError, match="Unrecognized config options"): + with cfg.overrides(nonexistent=True): + pass + + def test_override_no_change_for_same_value(self) -> None: + """Test that overriding with same value doesn't trigger unnecessary changes.""" + + class TestConfig(ConfigManager): + value = OptionDescriptor(option_type=int, default=10) + + cfg = TestConfig() + with cfg.overrides(value=10): + assert cfg.value == 10 + + +class TestValidation: + """Test configuration option validation.""" + + def 
test_validator_rejects_invalid_values(self) -> None: + """Test that validators reject invalid values.""" + + def positive_int(val: Any) -> None: + if not isinstance(val, int) or val <= 0: + raise ValueError("Must be positive") + + class TestConfig(ConfigManager): + count = OptionDescriptor(option_type=int, default=1, validator=positive_int) + + cfg = TestConfig() + with pytest.raises(ValueError, match="Must be positive"): + cfg.set("count", -5) + + def test_type_check_validator(self) -> None: + """Test that 'type_check' validator validates types.""" + + class TestConfig(ConfigManager): + name = OptionDescriptor(option_type=str, default="test", validator="type_check") + + cfg = TestConfig() + with pytest.raises(TypeError): + cfg.set("name", 123) + + def test_validator_accepts_valid_values(self) -> None: + """Test that validators accept valid values.""" + + def even_int(val: Any) -> None: + if not isinstance(val, int) or val % 2 != 0: + raise ValueError("Must be even") + + class TestConfig(ConfigManager): + num = OptionDescriptor(option_type=int, default=2, validator=even_int) + + cfg = TestConfig() + cfg.set("num", 42) + assert cfg.num == 42 + + def test_validator_applied_during_context_override(self) -> None: + """Test that validators are applied during context overrides.""" + + def positive(val: Any) -> None: + if val <= 0: + raise ValueError("Must be positive") + + class TestConfig(ConfigManager): + value = OptionDescriptor(option_type=int, default=1, validator=positive) + + cfg = TestConfig() + with pytest.raises(ValueError, match="Must be positive"): + with cfg.overrides(value=-1): + pass + + +class TestUpdateCallbacks: + """Test option update callbacks.""" + + def test_callback_invoked_on_global_set(self) -> None: + """Test that callbacks are invoked when using set().""" + callback_calls: list[tuple[Any, Any, UpdateScope]] = [] + + def track_changes(new_val: Any, old_val: Any, scope: UpdateScope) -> None: + callback_calls.append((new_val, old_val, scope)) + + 
class TestConfig(ConfigManager): + value = OptionDescriptor(option_type=int, default=10, update_callback=track_changes) + + cfg = TestConfig() + cfg.set("value", 20) + + assert len(callback_calls) == 1 + assert callback_calls[0] == (20, 10, UpdateScope.GLOBAL) + + def test_callback_invoked_on_context_override(self) -> None: + """Test that callbacks are invoked during context overrides.""" + callback_calls: list[tuple[Any, Any, UpdateScope]] = [] + + def track_changes(new_val: Any, old_val: Any, scope: UpdateScope) -> None: + callback_calls.append((new_val, old_val, scope)) + + class TestConfig(ConfigManager): + value = OptionDescriptor(option_type=int, default=10, update_callback=track_changes) + + cfg = TestConfig() + with cfg.overrides(value=20): + pass + + # Should have one call on enter and one on exit + assert any(call[2] == UpdateScope.CONTEXT for call in callback_calls) + def test_no_callback_for_no_change(self) -> None: + """Test that callbacks are not invoked when override value equals current value.""" + callback_calls: list[Any] = [] -@pytest.fixture -def env_var(): - """Just in case another test will ever use that environment variable.""" - env_var_name = "GT4PY_TEST_ENV_VAR" - saved = os.environ.get(env_var_name, None) - yield env_var_name - if saved is not None: - os.environ[env_var_name] = saved - else: - _ = os.environ.pop(env_var_name, None) + def track_changes(new_val: Any, old_val: Any, scope: UpdateScope) -> None: + callback_calls.append("called") + class TestConfig(ConfigManager): + value = OptionDescriptor(option_type=int, default=10, update_callback=track_changes) -@pytest.mark.parametrize("value", ["False", "false", "0", "off"]) -def test_env_flag_to_bool_false(env_var, value): - os.environ[env_var] = value - assert config.env_flag_to_bool(env_var, default=True) is False + cfg = TestConfig() + with cfg.overrides(value=10): # Same as default + pass + assert len(callback_calls) == 0 -@pytest.mark.parametrize("value", ["True", "true", "1", 
"on"]) -def test_env_flag_to_bool_true(env_var, value): - os.environ[env_var] = value - assert config.env_flag_to_bool(env_var, default=False) is True +def test_gt4py_config_class() -> None: + """Test the actual Config class for GT4Py.""" -def test_env_flag_to_bool_invalid(env_var): - os.environ[env_var] = "invalid value" - with pytest.raises(ValueError): - config.env_flag_to_bool(env_var, default=False) + assert isinstance(Config, type) + cfg = Config() + assert "debug" in cfg._option_descriptors_() + assert isinstance(cfg.debug, bool) -def test_env_flag_to_bool_unset(env_var): - _ = os.environ.pop(env_var, None) - assert config.env_flag_to_bool(env_var, default=False) is False + assert isinstance(cfg.build_cache_dir, pathlib.Path) + assert str(cfg.build_cache_dir).endswith(".gt4py_cache")