Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,15 @@ def _set_start_step_from_end_step_ceiled_to_24_hours(startStep, endStep, field=N
return endStep - (endStep % 24), endStep


# NOTE(review): this single-entry registry is reassigned by the two-entry
# ``patch_registry`` literal a few lines further down, so this assignment has no
# lasting effect. It looks like the pre-change line of a diff — confirm and drop it.
patch_registry = {"reset_24h_accumulations": _set_start_step_from_end_step_ceiled_to_24_hours}
def _set_start_step_to_zero(startStep, endStep, field=None):
# Because the data wrongly encode start_step, but end_step is correct
return 0, endStep


# Registry of step-patching functions, keyed by the name used to select them
# (presumably from user configuration — TODO confirm against the caller).
# Each registered function takes (startStep, endStep, field=None) and returns a
# (start_step, end_step) pair.
patch_registry = {
"reset_24h_accumulations": _set_start_step_from_end_step_ceiled_to_24_hours,
"set_start_step_to_zero": _set_start_step_to_zero,
}


class FieldToInterval:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from abc import abstractmethod
from collections.abc import Iterable

from anemoi.utils.dates import as_datetime
from anemoi.utils.dates import frequency_to_timedelta

from anemoi.datasets.create.sources.accumulate_utils.covering_intervals import SignedInterval
Expand Down Expand Up @@ -149,6 +150,85 @@ def __call__(
return intervals


class CycleIntervalProvider(SearchableIntervalGenerator):
    """Interval generator driven by an explicitly configured, repeating cycle.

    The configuration maps a requested window, expressed as ``"<i>-<j>"`` hours
    within the cycle, to a ``(base_time, steps)`` pair:

    * ``base_time`` is used as the hour of day of the base datetime;
    * ``steps`` is a ``"/"``-separated string of ``"<a>-<b>"`` step ranges, in
      hours relative to that base.

    The cycle length is the largest ``j`` found among the configured keys.
    Positions within the cycle are counted in whole hours from ``reference``.
    """

    def __init__(self, **config: dict):
        """Parse the cycle configuration.

        Args:
            **config: An optional ``start`` entry giving the reference datetime
                of the cycle (defaults to 1970-01-01 00:00), plus one entry per
                supported window: key ``"<i>-<j>"``, value a two-item sequence
                ``(base_time, "<a>-<b>/<c>-<d>/...")``.

        Raises:
            ValueError: If a key or a step range is not of the form ``"<int>-<int>"``.
        """
        # Function-scope import: diagnostics go through ``logging`` instead of
        # being printed to stdout from library code.
        import logging

        logging.getLogger(__name__).debug("CycleIntervalProvider config: %s", config)

        self.reference = as_datetime(config.pop("start", datetime.datetime(1970, 1, 1, 0, 0)))

        def parse_range(text):
            # "3-6" -> (3, 6)
            first, last = text.split("-")
            return int(first), int(last)

        def normalise_steps(base_time, steps):
            # "0-3/3-6" -> [(0, 3), (3, 6)]; base_time is passed through untouched.
            return base_time, [parse_range(s) for s in steps.split("/")]

        self.config = {parse_range(k): normalise_steps(*v) for k, v in config.items()}

    def covering_intervals(self, start: datetime.datetime, end: datetime.datetime) -> Iterable[SignedInterval]:
        """Return the configured intervals for the window ``[start, end]``.

        Args:
            start: Beginning of the requested window.
            end: End of the requested window; must be after ``start``.

        Returns:
            One ``SignedInterval`` per configured step range, all sharing the
            same base datetime.

        Raises:
            ValueError: If nothing is configured, the window is not positive,
                the window does not correspond to a configured key, a configured
                step range is invalid, or no resulting interval starts at
                ``start`` / ends at ``end``.
        """
        if not self.config:
            raise ValueError("CycleIntervalProvider: no cycle entries configured")

        # Raised explicitly (not asserted) so the check survives ``python -O``.
        if not end > start:
            raise ValueError("CycleIntervalProvider only supports positive intervals (end must be after start)")

        cycle_length_in_hours = max(k[1] for k in self.config)
        if cycle_length_in_hours <= 0:
            raise ValueError(f"CycleIntervalProvider: cycle length must be positive ({cycle_length_in_hours=})")

        # With a positive modulus, ``%`` always yields a value in
        # [0, cycle_length), so i_start is in range by construction and i_end
        # ends up in (0, cycle_length] once 0 is mapped to the end of the cycle.
        i_start = (int((start - self.reference).total_seconds()) // 3600) % cycle_length_in_hours
        i_end = (int((end - self.reference).total_seconds()) // 3600) % cycle_length_in_hours
        if i_end == 0:
            i_end = cycle_length_in_hours

        if i_start >= i_end:
            raise ValueError(f"CycleIntervalProvider: i_start={i_start} >= i_end={i_end} (start={start}, end={end})")

        if (i_start, i_end) not in self.config:
            raise ValueError(
                f"CycleIntervalProvider: no config to find ({i_start}, {i_end}) (start={start}, end={end}, {cycle_length_in_hours=})"
            )

        base_time, steps = self.config[(i_start, i_end)]

        base_datetime = datetime.datetime(end.year, end.month, end.day, base_time)
        # The base must be strictly before end so step 0-N lands on or before end.
        # Stepping back whole days keeps the hour equal to base_time.
        while base_datetime >= end:
            base_datetime -= datetime.timedelta(days=1)

        intervals = []
        for start_step, end_step in steps:
            if start_step < 0:
                raise ValueError(f"start_step {start_step} must be non-negative")
            if end_step <= start_step:
                raise ValueError(f"end_step {end_step} must be greater than start_step {start_step}")
            intervals.append(
                SignedInterval(
                    base=base_datetime,
                    start=base_datetime + datetime.timedelta(hours=start_step),
                    end=base_datetime + datetime.timedelta(hours=end_step),
                )
            )

        # The requested bounds must be hit by some interval, either as-is or negated.
        if not any(i.start == start or (-i).start == start for i in intervals):
            raise ValueError(
                f"CycleIntervalProvider: no interval starting at {start} (start={start}, end={end}, {cycle_length_in_hours=})"
            )
        if not any(i.end == end or (-i).end == end for i in intervals):
            raise ValueError(
                f"CycleIntervalProvider: no interval ending at {end} (start={start}, end={end}, {cycle_length_in_hours=})"
            )

        return intervals


def normalise_steps(steps_list: str | list[str]) -> list[list[int]]:
"""Convert the input step_list to a list of [start,end] pairs"""
res = []
Expand Down Expand Up @@ -263,6 +343,11 @@ def _interval_generator_factory(
case {"accumulated-from-start": params}:
return AccumulatedFromStartIntervalGenerator(**params)

case {"type": "cycle", **params}:
return CycleIntervalProvider(**params)
case {"cycle": params}:
return CycleIntervalProvider(**params)

case {"accumulated-from-previous-step": params}:
return AccumulatedFromPreviousStepIntervalGenerator(**params)
case {"type": "accumulated-from-previous-step", **params}:
Expand Down
Loading