Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/anemoi/datasets/create/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ def check_name(
dates: list[datetime.datetime],
frequency: datetime.timedelta,
raise_exception: bool = False,
is_test: bool = False,
) -> None:
"""Check the name of the dataset.

Expand All @@ -270,12 +271,14 @@ def check_name(
The frequency of the dataset.
raise_exception : bool, optional
Whether to raise an exception if the name is invalid.
is_test : bool, optional
Whether running in test mode.
"""
basename, _ = os.path.splitext(os.path.basename(self.path))
try:
DatasetName(basename, resolution, dates[0], dates[-1], frequency).raise_if_not_valid()
except Exception as e:
if raise_exception:
if raise_exception and not is_test:
raise
else:
LOG.warning(f"Dataset name error: {e}")
Expand Down Expand Up @@ -571,6 +574,7 @@ def __init__(
config: dict,
check_name: bool = False,
overwrite: bool = False,
test: bool = False,
use_threads: bool = False,
statistics_temp_dir: str | None = None,
progress: Any = None,
Expand Down Expand Up @@ -604,11 +608,12 @@ def __init__(
super().__init__(path, cache=cache)
self.config = config
self.check_name = check_name
self.test = test
self.use_threads = use_threads
self.statistics_temp_dir = statistics_temp_dir
self.progress = progress

self.main_config = loader_config(config)
self.main_config = loader_config(config, is_test=test)

# self.registry.delete() ??
self.tmp_statistics.delete()
Expand Down Expand Up @@ -744,6 +749,7 @@ def _run(self) -> int:
resolution=resolution,
dates=dates,
frequency=frequency,
is_test=self.test,
)

if len(dates) != total_shape[0]:
Expand Down
51 changes: 50 additions & 1 deletion src/anemoi/datasets/create/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from anemoi.utils.config import load_any_dict_format
from earthkit.data.core.order import normalize_order_by

from anemoi.datasets.dates.groups import Groups

LOG = logging.getLogger(__name__)


Expand Down Expand Up @@ -338,20 +340,67 @@ def _prepare_serialisation(o: Any) -> Any:
return str(o)


def loader_config(config: dict) -> LoadersConfig:
def set_to_test_mode(cfg: dict) -> None:
NUMBER_OF_DATES = 4
LOG.warning(f"Running in test mode. Changing the list of dates to use only {NUMBER_OF_DATES}.")
groups = Groups(**LoadersConfig(cfg).dates)
dates = groups.provider.values
cfg["dates"] = dict(
start=dates[0],
end=dates[NUMBER_OF_DATES - 1],
frequency=groups.provider.frequency,
group_by=NUMBER_OF_DATES,
)

num_ensembles = count_ensembles(cfg)

def set_element_to_test(obj):
if isinstance(obj, (list, tuple)):
for v in obj:
set_element_to_test(v)
return

if isinstance(obj, (dict, DotDict)):
if "grid" in obj and num_ensembles > 1:
previous = obj["grid"]
obj["grid"] = "20./20."
LOG.warning(f"Running in test mode. Setting grid to {obj['grid']} instead of {previous}")

if "number" in obj and num_ensembles > 1:
if isinstance(obj["number"], (list, tuple)):
previous = obj["number"]
obj["number"] = previous[0:3]
LOG.warning(f"Running in test mode. Setting number to {obj['number']} instead of {previous}")

for k, v in obj.items():
set_element_to_test(v)

if "constants" in obj:
constants = obj["constants"]
if "param" in constants and isinstance(constants["param"], list):
constants["param"] = ["cos_latitude"]

set_element_to_test(cfg)


def loader_config(config: dict, is_test: bool = False) -> LoadersConfig:
"""Loads and validates the configuration for dataset loaders.

Parameters
----------
config : dict
The configuration dictionary.
is_test : bool
If True, applies test mode to reduce dates, grid, and ensembles.

Returns
-------
LoadersConfig
The validated configuration object.
"""
config = Config(config)
if is_test:
set_to_test_mode(config)
obj = LoadersConfig(config)

# yaml round trip to check that serialisation works as expected
Expand Down
59 changes: 59 additions & 0 deletions tests/create/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import os

from anemoi.datasets.create.config import Config
from anemoi.datasets.create.config import LoadersConfig
from anemoi.datasets.create.config import set_to_test_mode
from anemoi.datasets.dates.groups import Groups

HERE = os.path.dirname(__file__)


def _load_config(name):
path = os.path.join(HERE, name)
return Config(path)


def test_set_to_test_mode_limits_dates_with_real_recipe():
cfg = _load_config("concat.yaml")

# Original config: 2020-12-30 00:00 to 2021-01-03 12:00 at 12h = 10 dates
original_groups = Groups(**LoadersConfig(cfg).dates)
original_dates = original_groups.provider.values
assert len(original_dates) == 10

set_to_test_mode(cfg)

# After test mode, should produce exactly 4 dates
test_groups = Groups(**cfg["dates"])
test_dates = test_groups.provider.values
assert len(test_dates) == 4
assert cfg["dates"]["group_by"] == 4


def test_set_to_test_mode_reduces_grid_and_ensemble():
cfg = Config(
{
"dates": {"start": "2020-12-30 00:00:00", "end": "2021-01-03 12:00:00", "frequency": "12h"},
"input": {
"mars": {
"grid": "0.25/0.25",
"number": [0, 1, 2, 3, 4, 5],
"param": ["2t"],
}
},
}
)

set_to_test_mode(cfg)

assert cfg["input"]["mars"]["grid"] == "20./20."
assert cfg["input"]["mars"]["number"] == [0, 1, 2]
Loading