diff --git a/mllam_data_prep/config.py b/mllam_data_prep/config.py index 8a7ccfd..3bd8189 100644 --- a/mllam_data_prep/config.py +++ b/mllam_data_prep/config.py @@ -1,4 +1,5 @@ from dataclasses import dataclass, field +from datetime import datetime, timedelta from typing import Any, Dict, List, Optional, Union import dataclass_wizard @@ -72,9 +73,9 @@ class Range: then the entire range will be selected. """ - start: Union[str, int, float] - end: Union[str, int, float] - step: Optional[Union[str, int, float]] = None + start: Union[str, int, float, datetime] + end: Union[str, int, float, datetime] + step: Union[str, int, float, timedelta, None] = None @dataclass diff --git a/mllam_data_prep/ops/selection.py b/mllam_data_prep/ops/selection.py index 37b91c1..2b68b98 100644 --- a/mllam_data_prep/ops/selection.py +++ b/mllam_data_prep/ops/selection.py @@ -1,32 +1,19 @@ -import datetime +import warnings +import numpy as np import pandas as pd -from ..config import Range +def to_timestamp(s): + if isinstance(s, str): + return pd.Timestamp(s) + return s -def _normalize_slice_startstop(s): - if isinstance(s, pd.Timestamp): - return s - elif isinstance(s, str): - try: - return pd.Timestamp(s) - except ValueError: - return s - else: - return s - -def _normalize_slice_step(s): - if isinstance(s, pd.Timedelta): - return s - elif isinstance(s, str): - try: - return pd.to_timedelta(s) - except ValueError: - return s - else: - return s +def to_timedelta(s): + if isinstance(s, str): + return np.timedelta64(pd.to_timedelta(s)) + return s def select_by_kwargs(ds, **coord_ranges): @@ -56,64 +43,44 @@ def select_by_kwargs(ds, **coord_ranges): """ for coord, selection in coord_ranges.items(): - if coord not in ds.coords: - raise ValueError(f"Coordinate {coord} not found in dataset") - if isinstance(selection, Range): - if selection.start is None and selection.end is None: - raise ValueError( - f"Selection for coordinate {coord} must have either 'start' and 'end' given" - ) - sel_start = _normalize_slice_startstop(selection.start) - sel_end = _normalize_slice_startstop(selection.end) - sel_step = _normalize_slice_step(selection.step) - - assert sel_start != sel_end, "Start and end cannot be the same" - - # we don't select with the step size for now, but simply check (below) that - # the step size in the data is the same as the requested step size - ds = ds.sel({coord: slice(sel_start, sel_end)}) - - if coord == "time": - check_point_in_dataset(coord, sel_start, ds) - check_point_in_dataset(coord, sel_end, ds) - if sel_step is not None: - check_step(sel_step, coord, ds) - - assert ( - len(ds[coord]) > 0 - ), f"You have selected an empty range {sel_start}:{sel_end} for coordinate {coord}" - - elif isinstance(selection, list): - ds = ds.sel({coord: selection}) - else: - raise NotImplementedError( - f"Selection for coordinate {coord} must be a list or a dict" - ) + sel_start = selection.start + sel_end = selection.end + sel_step = selection.step + + if coord == "time": + sel_start = to_timestamp(selection.start) + sel_end = to_timestamp(selection.end) + sel_step = get_time_step(sel_step, ds) + + assert sel_start != sel_end, "Start and end cannot be the same" + + check_selection(ds, coord, sel_start, sel_end) + ds = ds.sel({coord: slice(sel_start, sel_end, sel_step)}) + + assert ( + len(ds[coord]) > 0 + ), f"You have selected an empty range {sel_start}:{sel_end} for coordinate {coord}" + return ds -def check_point_in_dataset(coord, point, ds): - """ - check that the requested point is in the data. - """ - if point is not None and point not in ds[coord].values: +def get_time_step(sel_step, ds): + if sel_step is None: + return None + + dataset_timedelta = ds.time[1] - ds.time[0] + sel_timedelta = to_timedelta(sel_step) + step = sel_timedelta / dataset_timedelta + if step % 1 != 0: raise ValueError( - f"Provided value for coordinate {coord} ({point}) is not in the data." + f"The chosen stepsize {sel_step} is not multiple of the stepsize in the dataset {dataset_timedelta}" ) + return int(step) -def check_step(sel_step, coord, ds): - """ - check that the step requested is exactly what the data has - """ - all_steps = ds[coord].diff(dim=coord).values - first_step = all_steps[0].astype("timedelta64[s]").astype(datetime.timedelta) - if not all(all_steps[0] == all_steps): - raise ValueError( - f"Step size for coordinate {coord} is not constant: {all_steps}" - ) - if sel_step != first_step: - raise ValueError( - f"Step size for coordinate {coord} is not the same as requested: {first_step} != {sel_step}" +def check_selection(ds, coord, sel_start, sel_end): + if ds[coord].values.min() < sel_start or ds[coord].values.max() > sel_end: + warnings.warn( + f"\nChosen slice exceeds the range of {coord} in the dataset.\n Dataset span: [ {ds[coord].values.min()} : {ds[coord].values.max()} ]\n Chosen slice: [ {sel_start} : {sel_end} ]\n" ) diff --git a/tests/resources/sliced_example.danra.yaml b/tests/resources/sliced_example.danra.yaml index 6d60d85..7015021 100644 --- a/tests/resources/sliced_example.danra.yaml +++ b/tests/resources/sliced_example.danra.yaml @@ -4,11 +4,6 @@ dataset_version: v0.1.0 output: variables: state: [time, grid_index, state_feature] - coord_ranges: - time: - start: 1990-09-03T00:00 - end: 1990-09-09T00:00 - step: PT3H chunking: time: 1 splitting: @@ -58,5 +53,9 @@ inputs: y: start: -50000 end: -40000 + time: + start: 1990-09-03T00:00 + end: 1990-09-09T00:00 + step: PT3H target_output_variable: state diff --git a/tests/resources/sliced_example_with_datetime_strings.danra.yaml b/tests/resources/sliced_example_with_datetime_strings.danra.yaml new file mode 100644 index 0000000..2225619 --- /dev/null +++ b/tests/resources/sliced_example_with_datetime_strings.danra.yaml @@ -0,0 +1,61 @@ +schema_version: v0.6.0 +dataset_version: v0.1.0 + +output: + variables: + state: [time, grid_index, state_feature] + chunking: + time: 1 + splitting: + dim: time + splits: + train: + start: 1990-09-03T00:00 + end: 1990-09-06T00:00 + compute_statistics: + ops: [mean, std, diff_mean, diff_std] + dims: [grid_index, time] + val: + start: 1990-09-06T00:00 + end: 1990-09-07T00:00 + test: + start: 1990-09-07T00:00 + end: 1990-09-09T00:00 + +inputs: + danra_height_levels: + path: https://mllam-test-data.s3.eu-north-1.amazonaws.com/height_levels.zarr + dims: [time, x, y, altitude] + variables: + u: + altitude: + values: [100,] + units: m + v: + altitude: + values: [100, ] + units: m + dim_mapping: + time: + method: rename + dim: time + state_feature: + method: stack_variables_by_var_name + dims: [altitude] + name_format: "{var_name}{altitude}m" + grid_index: + method: stack + dims: [x, y] + coord_ranges: + x: + start: -50000 + end: -40000 + y: + start: -50000 + end: -40000 + time: + start: "1990-09-03T00:00" + end: "1990-09-09T00:00" + step: "PT3H" + + target_output_variable: state diff --git a/tests/test_config.py b/tests/test_config.py index 5f7896a..633e1a1 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,7 +1,10 @@ +import datetime + import pytest from dataclass_wizard.errors import MissingFields, UnknownJSONKey import mllam_data_prep as mdp +from mllam_data_prep import config INVALID_EXTRA_FIELDS_CONFIG_YAML = """ schema_version: v0.1.0 @@ -110,6 +113,16 @@ def test_get_config_issues(): """ +def test_can_load_config_with_datetime_object_in_time_range(): + fp = "tests/resources/sliced_example.danra.yaml" + mdp.Config.from_yaml_file(fp) + + +def test_can_load_config_with_datetime_string_in_time_range(): + fp = "tests/resources/sliced_example_with_datetime_strings.danra.yaml" + mdp.Config.from_yaml_file(fp) + + def test_get_config_nested(): config = mdp.Config.from_yaml(VALID_EXAMPLE_CONFIG_YAML) @@ -121,6 +134,14 @@ def test_get_config_nested(): input_config.foobarfield +def test_range_accepts_datetime(): + start = datetime.datetime(1990, 9, 3, 0, 0) + end = datetime.datetime(1990, 9, 4, 0, 0) + step = "PT3H" + + config.Range(start=start, end=end, step=step) + + def test_config_roundtrip(): original_config = mdp.Config.from_yaml(VALID_EXAMPLE_CONFIG_YAML) roundtrip_config_dict = mdp.Config.from_dict(original_config.to_dict()) diff --git a/tests/test_from_config.py b/tests/test_from_config.py index 1a89361..2e0ea45 100644 --- a/tests/test_from_config.py +++ b/tests/test_from_config.py @@ -112,91 +112,6 @@ def test_merging_static_and_surface_analysis(): mdp.create_dataset_zarr(fp_config=fp_config) -@pytest.mark.parametrize("source_data_contains_time_range", [True, False]) -@pytest.mark.parametrize( - "time_stepsize", - [testdata.DT_ANALYSIS, testdata.DT_ANALYSIS * 2, testdata.DT_ANALYSIS / 2], -) -def test_time_selection(source_data_contains_time_range, time_stepsize): - """ - Check that time selection works as expected, so that when source - data doesn't contain the time range specified in the config and exception - is raised, and otherwise that the correct timesteps are in the output - """ - - tmpdir = tempfile.TemporaryDirectory() - datasets = testdata.create_data_collection( - data_kinds=["surface_analysis", "static"], fp_root=tmpdir.name - ) - - t_start_dataset = testdata.T_START - t_end_dataset = t_start_dataset + (testdata.NT_ANALYSIS - 1) * testdata.DT_ANALYSIS - - if source_data_contains_time_range: - t_start_config = t_start_dataset - t_end_config = t_end_dataset - else: - t_start_config = t_start_dataset - testdata.DT_ANALYSIS - t_end_config = t_end_dataset + testdata.DT_ANALYSIS - - config = dict( - schema_version=testdata.SCHEMA_VERSION, - dataset_version="v0.1.0", - output=dict( - variables=dict( - static=["grid_index", "feature"], - state=["time", "grid_index", "feature"], - forcing=["time", "grid_index", "feature"], - ), - coord_ranges=dict( - time=dict( - start=t_start_config.isoformat(), - end=t_end_config.isoformat(), - step=isodate.duration_isoformat(time_stepsize), - ) - ), - ), - inputs=dict( - danra_surface=dict( - path=datasets["surface_analysis"], - dims=["analysis_time", "x", "y"], - variables=testdata.DEFAULT_SURFACE_ANALYSIS_VARS, - dim_mapping=dict( - time=dict( - method="rename", - dim="analysis_time", - ), - grid_index=dict( - method="stack", - dims=["x", "y"], - ), - feature=dict( - method="stack_variables_by_var_name", - name_format="{var_name}", - ), - ), - target_output_variable="forcing", - ), - ), - ) - - # write yaml config to file - fn_config = "config.yaml" - fp_config = Path(tmpdir.name) / fn_config - with open(fp_config, "w") as f: - yaml.dump(config, f) - - # run the main function - if source_data_contains_time_range and time_stepsize == testdata.DT_ANALYSIS: - mdp.create_dataset_zarr(fp_config=fp_config) - else: - print( - f"Expecting ValueError for source_data_contains_time_range={source_data_contains_time_range} and time_stepsize={time_stepsize}" - ) - with pytest.raises(ValueError): - mdp.create_dataset_zarr(fp_config=fp_config) - - @pytest.mark.parametrize("use_common_feature_var_name", [True, False]) def test_feature_collision(use_common_feature_var_name): """ @@ -360,7 +275,6 @@ def test_config_revision_examples(fp_example): """ tmpdir = tempfile.TemporaryDirectory() - # copy example to tempdir fp_config_copy = Path(tmpdir.name) / fp_example.name shutil.copy(fp_example, fp_config_copy) diff --git a/tests/test_selection.py b/tests/test_selection.py index 044b66e..f287617 100644 --- a/tests/test_selection.py +++ b/tests/test_selection.py @@ -1,3 +1,5 @@ +import isodate +import numpy as np import pytest import xarray as xr @@ -10,7 +12,8 @@ def ds(): Load the height_levels.zarr dataset """ fp = "https://mllam-test-data.s3.eu-north-1.amazonaws.com/height_levels.zarr" - return xr.open_zarr(fp) + ds = xr.open_zarr(fp) + return ds def test_range_slice_within_range(ds): @@ -32,8 +35,6 @@ def test_range_slice_within_range(ds): assert ds.y.min() >= y_start assert ds.y.max() <= y_end - ds - @pytest.mark.parametrize("x_start, x_end", ([-50000, -51000], [0, 500000])) def test_error_on_empty_range(ds, x_start, x_end): @@ -49,3 +50,42 @@ def test_error_on_empty_range(ds, x_start, x_end): with pytest.raises(AssertionError): ds = mdp.ops.selection.select_by_kwargs(ds, **coord_ranges) + + +def test_slice_time(ds): + start = "1990-09-01T00:00" + end = "1990-09-09T00:00" + coord_ranges = { + "time": mdp.config.Range(start=start, end=end), + } + + ds = mdp.ops.selection.select_by_kwargs(ds, **coord_ranges) + + +@pytest.mark.parametrize("step", ["PT6H", "PT3H"]) +def test_timestep_matches_output(ds, step): + start = "1990-09-01T00:00" + end = "1990-09-09T00:00" + coord_ranges = { + "time": mdp.config.Range(start=start, end=end, step=step), + } + + ds = mdp.ops.selection.select_by_kwargs(ds, **coord_ranges) + + td = isodate.parse_duration(step) + timestep_chosen_in_slice = np.timedelta64(int(td.total_seconds()), "s") + timestep_in_dataset = np.diff(ds.time)[0] + + assert timestep_chosen_in_slice == timestep_in_dataset + + +def test_raises_if_time_step_is_not_multiple_of_dataset_frequency(ds): + step = "PT5H" + start = "1990-09-01T03:00" + end = "1990-09-09T00:00" + coord_ranges = { + "time": mdp.config.Range(start=start, end=end, step=step), + } + + with pytest.raises(ValueError): + ds = mdp.ops.selection.select_by_kwargs(ds, **coord_ranges)