diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index 2e593532e..e6a16926a 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -257,6 +257,7 @@ def check_name( dates: list[datetime.datetime], frequency: datetime.timedelta, raise_exception: bool = False, + is_test: bool = False, ) -> None: """Check the name of the dataset. @@ -270,12 +271,14 @@ def check_name( The frequency of the dataset. raise_exception : bool, optional Whether to raise an exception if the name is invalid. + is_test : bool, optional + Whether running in test mode. """ basename, _ = os.path.splitext(os.path.basename(self.path)) try: DatasetName(basename, resolution, dates[0], dates[-1], frequency).raise_if_not_valid() except Exception as e: - if raise_exception: + if raise_exception and not is_test: raise else: LOG.warning(f"Dataset name error: {e}") @@ -571,6 +574,7 @@ def __init__( config: dict, check_name: bool = False, overwrite: bool = False, + test: bool = False, use_threads: bool = False, statistics_temp_dir: str | None = None, progress: Any = None, @@ -604,11 +608,12 @@ def __init__( super().__init__(path, cache=cache) self.config = config self.check_name = check_name + self.test = test self.use_threads = use_threads self.statistics_temp_dir = statistics_temp_dir self.progress = progress - self.main_config = loader_config(config) + self.main_config = loader_config(config, is_test=test) # self.registry.delete() ?? self.tmp_statistics.delete() @@ -744,6 +749,7 @@ def _run(self) -> int: resolution=resolution, dates=dates, frequency=frequency, + is_test=self.test, ) if len(dates) != total_shape[0]: diff --git a/src/anemoi/datasets/create/config.py b/src/anemoi/datasets/create/config.py index 2b55d673c..d2202911c 100644 --- a/src/anemoi/datasets/create/config.py +++ b/src/anemoi/datasets/create/config.py @@ -18,6 +18,8 @@ from anemoi.utils.config import load_any_dict_format from earthkit.data.core.order import normalize_order_by +from anemoi.datasets.dates.groups import Groups + LOG = logging.getLogger(__name__) @@ -338,13 +340,58 @@ def _prepare_serialisation(o: Any) -> Any: return str(o) -def loader_config(config: dict) -> LoadersConfig: +def set_to_test_mode(cfg: dict) -> None: + NUMBER_OF_DATES = 4 + LOG.warning(f"Running in test mode. Changing the list of dates to use only {NUMBER_OF_DATES}.") + groups = Groups(**LoadersConfig(cfg).dates) + dates = groups.provider.values + cfg["dates"] = dict( + start=dates[0], + end=dates[NUMBER_OF_DATES - 1], + frequency=groups.provider.frequency, + group_by=NUMBER_OF_DATES, + ) + + num_ensembles = count_ensembles(cfg) + + def set_element_to_test(obj): + if isinstance(obj, (list, tuple)): + for v in obj: + set_element_to_test(v) + return + + if isinstance(obj, (dict, DotDict)): + if "grid" in obj and num_ensembles > 1: + previous = obj["grid"] + obj["grid"] = "20./20." + LOG.warning(f"Running in test mode. Setting grid to {obj['grid']} instead of {previous}") + + if "number" in obj and num_ensembles > 1: + if isinstance(obj["number"], (list, tuple)): + previous = obj["number"] + obj["number"] = previous[0:3] + LOG.warning(f"Running in test mode. Setting number to {obj['number']} instead of {previous}") + + for k, v in obj.items(): + set_element_to_test(v) + + if "constants" in obj: + constants = obj["constants"] + if "param" in constants and isinstance(constants["param"], list): + constants["param"] = ["cos_latitude"] + + set_element_to_test(cfg) + + +def loader_config(config: dict, is_test: bool = False) -> LoadersConfig: """Loads and validates the configuration for dataset loaders. Parameters ---------- config : dict The configuration dictionary. + is_test : bool + If True, applies test mode to reduce dates, grid, and ensembles. Returns ------- @@ -352,6 +399,8 @@ def loader_config(config: dict) -> LoadersConfig: The validated configuration object. """ config = Config(config) + if is_test: + set_to_test_mode(config) obj = LoadersConfig(config) # yaml round trip to check that serialisation works as expected diff --git a/tests/create/test_config.py b/tests/create/test_config.py new file mode 100644 index 000000000..7608358dc --- /dev/null +++ b/tests/create/test_config.py @@ -0,0 +1,59 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import os + +from anemoi.datasets.create.config import Config +from anemoi.datasets.create.config import LoadersConfig +from anemoi.datasets.create.config import set_to_test_mode +from anemoi.datasets.dates.groups import Groups + +HERE = os.path.dirname(__file__) + + +def _load_config(name): + path = os.path.join(HERE, name) + return Config(path) + + +def test_set_to_test_mode_limits_dates_with_real_recipe(): + cfg = _load_config("concat.yaml") + + # Original config: 2020-12-30 00:00 to 2021-01-03 12:00 at 12h = 10 dates + original_groups = Groups(**LoadersConfig(cfg).dates) + original_dates = original_groups.provider.values + assert len(original_dates) == 10 + + set_to_test_mode(cfg) + + # After test mode, should produce exactly 4 dates + test_groups = Groups(**cfg["dates"]) + test_dates = test_groups.provider.values + assert len(test_dates) == 4 + assert cfg["dates"]["group_by"] == 4 + + +def test_set_to_test_mode_reduces_grid_and_ensemble(): + cfg = Config( + { + "dates": {"start": "2020-12-30 00:00:00", "end": "2021-01-03 12:00:00", "frequency": "12h"}, + "input": { + "mars": { + "grid": "0.25/0.25", + "number": [0, 1, 2, 3, 4, 5], + "param": ["2t"], + } + }, + } + ) + + set_to_test_mode(cfg) + + assert cfg["input"]["mars"]["grid"] == "20./20." + assert cfg["input"]["mars"]["number"] == [0, 1, 2]