diff --git a/src/anemoi/datasets/create/sources/accumulate_utils/field_to_interval.py b/src/anemoi/datasets/create/sources/accumulate_utils/field_to_interval.py
index 53127011..9eb6f430 100644
--- a/src/anemoi/datasets/create/sources/accumulate_utils/field_to_interval.py
+++ b/src/anemoi/datasets/create/sources/accumulate_utils/field_to_interval.py
@@ -43,7 +43,19 @@ def _set_start_step_from_end_step_ceiled_to_24_hours(startStep, endStep, field=N
     return endStep - (endStep % 24), endStep
 
 
-patch_registry = {"reset_24h_accumulations": _set_start_step_from_end_step_ceiled_to_24_hours}
+def _set_start_step_to_zero(startStep, endStep, field=None):
+    """Return ``(0, endStep)``, ignoring the encoded ``startStep``.
+
+    Patch for data that encode the start step wrongly while the end step is
+    correct. ``field`` is not used.
+    """
+    return 0, endStep
+
+
+patch_registry = {
+    "reset_24h_accumulations": _set_start_step_from_end_step_ceiled_to_24_hours,
+    "set_start_step_to_zero": _set_start_step_to_zero,
+}
 
 
 class FieldToInterval:
diff --git a/src/anemoi/datasets/create/sources/accumulate_utils/interval_generators.py b/src/anemoi/datasets/create/sources/accumulate_utils/interval_generators.py
index 2ece7571..61446726 100644
--- a/src/anemoi/datasets/create/sources/accumulate_utils/interval_generators.py
+++ b/src/anemoi/datasets/create/sources/accumulate_utils/interval_generators.py
@@ -13,6 +13,7 @@
 from abc import abstractmethod
 from collections.abc import Iterable
 
+from anemoi.utils.dates import as_datetime
 from anemoi.utils.dates import frequency_to_timedelta
 
 from anemoi.datasets.create.sources.accumulate_utils.covering_intervals import SignedInterval
@@ -149,6 +150,102 @@ def __call__(
         return intervals
 
 
+class CycleIntervalProvider(SearchableIntervalGenerator):
+    """Interval generator driven by a repeating cycle of hour windows.
+
+    Every keyword other than ``start`` is one cycle entry. Its name
+    ``"<i>-<j>"`` is a window in whole hours since ``start``, taken modulo
+    the cycle length (the largest ``j`` of all entries). Its value is a pair
+    ``(base_time, steps)``: ``base_time`` is the hour of day of the base
+    date and ``steps`` is a ``/``-separated list of
+    ``"<start_step>-<end_step>"`` hour ranges counted from that base.
+    ``start`` is the reference date and defaults to 1970-01-01 00:00.
+    """
+
+    def __init__(self, **config: dict):
+        self.reference = as_datetime(config.pop("start", datetime.datetime(1970, 1, 1, 0, 0)))
+
+        def parse_range(text):
+            first, last = text.split("-")
+            return int(first), int(last)
+
+        # Named parse_entry so it does not shadow the module-level normalise_steps().
+        def parse_entry(base_time, steps):
+            return int(base_time), [parse_range(s) for s in steps.split("/")]
+
+        self.config = {parse_range(k): parse_entry(*v) for k, v in config.items()}
+        if not self.config:
+            raise ValueError("CycleIntervalProvider: at least one cycle entry is required")
+
+        self.cycle_length_in_hours = max(j for _, j in self.config)
+        if self.cycle_length_in_hours <= 0:
+            raise ValueError(
+                f"CycleIntervalProvider: cycle length must be positive, got {self.cycle_length_in_hours}"
+            )
+
+    def covering_intervals(self, start: datetime.datetime, end: datetime.datetime) -> Iterable[SignedInterval]:
+        """Return the configured intervals for the cycle window matching ``start``/``end``.
+
+        Raises ``ValueError`` when ``end`` is not after ``start``, when no cycle
+        entry matches the window, when a configured step range is invalid, or
+        when no resulting interval starts at ``start`` or ends at ``end``.
+        """
+        # Raise rather than assert: assertions are stripped under ``python -O``.
+        if end <= start:
+            raise ValueError("CycleIntervalProvider only supports positive intervals (end must be after start)")
+
+        cycle_length_in_hours = self.cycle_length_in_hours
+
+        def hour_in_cycle(when):
+            return (int((when - self.reference).total_seconds()) // 3600) % cycle_length_in_hours
+
+        # The modulo keeps i_start in [0, cycle) and i_end in (0, cycle] by construction.
+        i_start = hour_in_cycle(start)
+        # An end that falls exactly on a cycle boundary closes the cycle instead of opening the next.
+        i_end = hour_in_cycle(end) or cycle_length_in_hours
+
+        if i_start >= i_end:
+            raise ValueError(f"CycleIntervalProvider: i_start={i_start} >= i_end={i_end} (start={start}, end={end})")
+
+        if (i_start, i_end) not in self.config:
+            raise ValueError(
+                f"CycleIntervalProvider: no config to find ({i_start}, {i_end}) (start={start}, end={end}, {cycle_length_in_hours=})"
+            )
+
+        base_time, steps = self.config[(i_start, i_end)]
+
+        base_datetime = datetime.datetime(end.year, end.month, end.day, base_time)
+        # The base must be strictly before end so step 0-N lands on or before end.
+        # It is built on end's own date, so stepping back one day is always enough.
+        if base_datetime >= end:
+            base_datetime -= datetime.timedelta(days=1)
+
+        intervals = []
+        for start_step, end_step in steps:
+            if start_step < 0:
+                raise ValueError(f"start_step {start_step} must be non-negative")
+            if end_step <= start_step:
+                raise ValueError(f"end_step {end_step} must be greater than start_step {start_step}")
+            intervals.append(
+                SignedInterval(
+                    base=base_datetime,
+                    start=base_datetime + datetime.timedelta(hours=start_step),
+                    end=base_datetime + datetime.timedelta(hours=end_step),
+                )
+            )
+
+        if not (any(i.start == start for i in intervals) or any((-i).start == start for i in intervals)):
+            raise ValueError(
+                f"CycleIntervalProvider: no interval starting at {start} (start={start}, end={end}, {cycle_length_in_hours=})"
+            )
+        if not (any(i.end == end for i in intervals) or any((-i).end == end for i in intervals)):
+            raise ValueError(
+                f"CycleIntervalProvider: no interval ending at {end} (start={start}, end={end}, {cycle_length_in_hours=})"
+            )
+
+        return intervals
+
+
 def normalise_steps(steps_list: str | list[str]) -> list[list[int]]:
     """Convert the input step_list to a list of [start,end] pairs"""
     res = []
@@ -263,6 +360,11 @@ def _interval_generator_factory(
         case {"accumulated-from-start": params}:
             return AccumulatedFromStartIntervalGenerator(**params)
 
+        case {"type": "cycle", **params}:
+            return CycleIntervalProvider(**params)
+        case {"cycle": params}:
+            return CycleIntervalProvider(**params)
+
         case {"accumulated-from-previous-step": params}:
             return AccumulatedFromPreviousStepIntervalGenerator(**params)
         case {"type": "accumulated-from-previous-step", **params}: