-
Notifications
You must be signed in to change notification settings - Fork 31
Fix selection #70
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Fix selection #70
Changes from 28 commits
13557aa
d3802f2
d7924dc
868558f
7d72542
824d45d
ff249e9
01df7c3
ede3b9a
570bc65
4dcba28
b23e426
d19b304
b38191c
a615ae7
87760f9
747986a
da5c70c
0f1466f
3cd8933
381bc64
85b5c27
5e217d1
96f5d93
7fc6f37
7b2845b
f5fd875
218c738
e993334
779530c
4d52059
3542f8b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,5 @@ | ||
| from dataclasses import dataclass, field | ||
| from datetime import datetime, timedelta | ||
| from typing import Any, Dict, List, Optional, Union | ||
|
|
||
| import dataclass_wizard | ||
|
|
@@ -72,9 +73,9 @@ class Range: | |
| then the entire range will be selected. | ||
| """ | ||
|
|
||
| start: Union[str, int, float] | ||
| end: Union[str, int, float] | ||
| step: Optional[Union[str, int, float]] = None | ||
| start: Union[str, int, float, datetime] | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ah, I'm an idiot. I am not sure we need to support
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. which I think should mean we should remove
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I imagine some folks might still write times as strings, but not allowing it in the first place would probably be more robust, so I'd be happy to drop
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm, I'm confused. Our current config schema supports the following: schema_version: v0.6.0
dataset_version: v0.1.0
output:
variables:
static: [grid_index, static_feature]
state: [time, grid_index, state_feature]
forcing: [time, grid_index, forcing_feature]
coord_ranges:
time:
start: 1990-09-03T00:00
end: 1990-09-09T00:00
step: PT3H
...What I am suggesting is that if we replace the
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was just experimenting together with @matschreiner and There seems to be an issue with a roundtrip though, as
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm that is strange. Omitting the seconds should be valid, https://en.wikipedia.org/wiki/ISO_8601#Times
Maybe this is an upstream bug, or YAML requires the seconds be included...
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In
this seems ok, no? It does mean the round trip wouldn't result in exactly the same yaml because the minutes will be added by |
||
| end: Union[str, int, float, datetime] | ||
| step: Union[str, int, float, timedelta, None] = None | ||
|
|
||
|
|
||
| @dataclass | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,32 +1,19 @@ | ||
| import datetime | ||
| import warnings | ||
|
|
||
| import numpy as np | ||
| import pandas as pd | ||
|
|
||
| from ..config import Range | ||
|
|
||
| def to_timestamp(s): | ||
| if isinstance(s, str): | ||
| return pd.Timestamp(s) | ||
| return s | ||
|
|
||
| def _normalize_slice_startstop(s): | ||
| if isinstance(s, pd.Timestamp): | ||
| return s | ||
| elif isinstance(s, str): | ||
| try: | ||
| return pd.Timestamp(s) | ||
| except ValueError: | ||
| return s | ||
| else: | ||
| return s | ||
|
|
||
|
|
||
| def _normalize_slice_step(s): | ||
| if isinstance(s, pd.Timedelta): | ||
| return s | ||
| elif isinstance(s, str): | ||
| try: | ||
| return pd.to_timedelta(s) | ||
| except ValueError: | ||
| return s | ||
| else: | ||
| return s | ||
| def to_timedelta(s): | ||
| if isinstance(s, str): | ||
| return np.timedelta64(pd.to_timedelta(s)) | ||
| return s | ||
|
|
||
|
|
||
| def select_by_kwargs(ds, **coord_ranges): | ||
|
|
@@ -56,64 +43,44 @@ def select_by_kwargs(ds, **coord_ranges): | |
| """ | ||
|
|
||
| for coord, selection in coord_ranges.items(): | ||
| if coord not in ds.coords: | ||
| raise ValueError(f"Coordinate {coord} not found in dataset") | ||
| if isinstance(selection, Range): | ||
| if selection.start is None and selection.end is None: | ||
| raise ValueError( | ||
| f"Selection for coordinate {coord} must have either 'start' and 'end' given" | ||
| ) | ||
| sel_start = _normalize_slice_startstop(selection.start) | ||
| sel_end = _normalize_slice_startstop(selection.end) | ||
| sel_step = _normalize_slice_step(selection.step) | ||
|
|
||
| assert sel_start != sel_end, "Start and end cannot be the same" | ||
|
|
||
| # we don't select with the step size for now, but simply check (below) that | ||
| # the step size in the data is the same as the requested step size | ||
| ds = ds.sel({coord: slice(sel_start, sel_end)}) | ||
|
|
||
| if coord == "time": | ||
| check_point_in_dataset(coord, sel_start, ds) | ||
| check_point_in_dataset(coord, sel_end, ds) | ||
| if sel_step is not None: | ||
| check_step(sel_step, coord, ds) | ||
|
|
||
| assert ( | ||
| len(ds[coord]) > 0 | ||
| ), f"You have selected an empty range {sel_start}:{sel_end} for coordinate {coord}" | ||
|
|
||
| elif isinstance(selection, list): | ||
| ds = ds.sel({coord: selection}) | ||
| else: | ||
| raise NotImplementedError( | ||
| f"Selection for coordinate {coord} must be a list or a dict" | ||
| ) | ||
| sel_start = selection.start | ||
| sel_end = selection.end | ||
| sel_step = selection.step | ||
|
|
||
| if coord == "time": | ||
| sel_start = to_timestamp(selection.start) | ||
| sel_end = to_timestamp(selection.end) | ||
| sel_step = get_time_step(sel_step, ds) | ||
|
Comment on lines
+50
to
+53
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree with @leifdenby that it would be good to only accept datetimes and so this rather boilerplate code is not necessary |
||
|
|
||
| assert sel_start != sel_end, "Start and end cannot be the same" | ||
|
|
||
|
Comment on lines
+55
to
+56
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| check_selection(ds, coord, sel_start, sel_end) | ||
| ds = ds.sel({coord: slice(sel_start, sel_end, sel_step)}) | ||
|
|
||
| assert ( | ||
| len(ds[coord]) > 0 | ||
| ), f"You have selected an empty range {sel_start}:{sel_end} for coordinate {coord}" | ||
|
|
||
| return ds | ||
|
|
||
|
|
||
| def check_point_in_dataset(coord, point, ds): | ||
| """ | ||
| check that the requested point is in the data. | ||
| """ | ||
| if point is not None and point not in ds[coord].values: | ||
| def get_time_step(sel_step, ds): | ||
| if sel_step is None: | ||
| return None | ||
|
|
||
| dataset_timedelta = ds.time[1] - ds.time[0] | ||
| sel_timedelta = to_timedelta(sel_step) | ||
| step = sel_timedelta / dataset_timedelta | ||
| if step % 1 != 0: | ||
| raise ValueError( | ||
| f"Provided value for coordinate {coord} ({point}) is not in the data." | ||
| f"The chosen stepsize {sel_step} is not multiple of the stepsize in the dataset {dataset_timedelta}" | ||
| ) | ||
|
|
||
| return int(step) | ||
|
|
||
| def check_step(sel_step, coord, ds): | ||
| """ | ||
| check that the step requested is exactly what the data has | ||
| """ | ||
| all_steps = ds[coord].diff(dim=coord).values | ||
| first_step = all_steps[0].astype("timedelta64[s]").astype(datetime.timedelta) | ||
|
|
||
| if not all(all_steps[0] == all_steps): | ||
| raise ValueError( | ||
| f"Step size for coordinate {coord} is not constant: {all_steps}" | ||
| ) | ||
| if sel_step != first_step: | ||
| raise ValueError( | ||
| f"Step size for coordinate {coord} is not the same as requested: {first_step} != {sel_step}" | ||
| def check_selection(ds, coord, sel_start, sel_end): | ||
| if ds[coord].values.min() < sel_start or ds[coord].values.max() > sel_end: | ||
| warnings.warn( | ||
| f"\nChosen slice exceeds the range of {coord} in the dataset.\n Dataset span: [ {ds[coord].values.min()} : {ds[coord].values.max()} ]\n Chosen slice: [ {sel_start} : {sel_end} ]\n" | ||
|
Comment on lines
81
to
+85
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This check assumes that the coordinates are sorted in ascending order, however, this is not always the case. Example: import xarray as xr
import numpy as np
# Create geographic coordinates
latitude = np.array([10, 5, 0, -5, -10]) # Latitude in descending order
longitude = np.array([30, 35, 40]) # Longitude in ascending order
# Create some sample data
data = np.random.rand(len(latitude), len(longitude)) # 2D data with latitudes and longitudes
# Create an xarray Dataset with geographic coordinates
ds = xr.Dataset(
{
'temperature': (('latitude', 'longitude'), data), # 2D data with latitude and longitude dimensions
'precipitation': (('latitude', 'longitude'), data * 0.1), # Example precipitation data
},
coords={
'latitude': latitude,
'longitude': longitude,
}
)
ds.sel({'latitude': slice(8,-5)}) # this slice is within bounds but would raise the above warning. It should also be Nevertheless, I would not do this test at all. The user should know what a valid range is. |
||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,61 @@ | ||
| schema_version: v0.6.0 | ||
| dataset_version: v0.1.0 | ||
|
|
||
| output: | ||
| variables: | ||
| state: [time, grid_index, state_feature] | ||
| chunking: | ||
| time: 1 | ||
| splitting: | ||
| dim: time | ||
| splits: | ||
| train: | ||
| start: 1990-09-03T00:00 | ||
| end: 1990-09-06T00:00 | ||
| compute_statistics: | ||
| ops: [mean, std, diff_mean, diff_std] | ||
| dims: [grid_index, time] | ||
| val: | ||
| start: 1990-09-06T00:00 | ||
| end: 1990-09-07T00:00 | ||
| test: | ||
| start: 1990-09-07T00:00 | ||
| end: 1990-09-09T00:00 | ||
|
|
||
| inputs: | ||
| danra_height_levels: | ||
| path: https://mllam-test-data.s3.eu-north-1.amazonaws.com/height_levels.zarr | ||
| dims: [time, x, y, altitude] | ||
| variables: | ||
| u: | ||
| altitude: | ||
| values: [100,] | ||
| units: m | ||
| v: | ||
| altitude: | ||
| values: [100, ] | ||
| units: m | ||
| dim_mapping: | ||
| time: | ||
| method: rename | ||
| dim: time | ||
| state_feature: | ||
| method: stack_variables_by_var_name | ||
| dims: [altitude] | ||
| name_format: "{var_name}{altitude}m" | ||
| grid_index: | ||
| method: stack | ||
| dims: [x, y] | ||
| coord_ranges: | ||
| x: | ||
| start: -50000 | ||
| end: -40000 | ||
| y: | ||
| start: -50000 | ||
| end: -40000 | ||
| time: | ||
| start: "1990-09-03T00:00" | ||
| end: "1990-09-09T00:00" | ||
| step: "PT3H" | ||
|
|
||
| target_output_variable: state |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -112,91 +112,6 @@ def test_merging_static_and_surface_analysis(): | |
| mdp.create_dataset_zarr(fp_config=fp_config) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("source_data_contains_time_range", [True, False]) | ||
| @pytest.mark.parametrize( | ||
| "time_stepsize", | ||
| [testdata.DT_ANALYSIS, testdata.DT_ANALYSIS * 2, testdata.DT_ANALYSIS / 2], | ||
| ) | ||
| def test_time_selection(source_data_contains_time_range, time_stepsize): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. how come you are getting rid of this test? Don't you like it 😆 |
||
| """ | ||
| Check that time selection works as expected, so that when source | ||
| data doesn't contain the time range specified in the config and exception | ||
| is raised, and otherwise that the correct timesteps are in the output | ||
| """ | ||
|
|
||
| tmpdir = tempfile.TemporaryDirectory() | ||
| datasets = testdata.create_data_collection( | ||
| data_kinds=["surface_analysis", "static"], fp_root=tmpdir.name | ||
| ) | ||
|
|
||
| t_start_dataset = testdata.T_START | ||
| t_end_dataset = t_start_dataset + (testdata.NT_ANALYSIS - 1) * testdata.DT_ANALYSIS | ||
|
|
||
| if source_data_contains_time_range: | ||
| t_start_config = t_start_dataset | ||
| t_end_config = t_end_dataset | ||
| else: | ||
| t_start_config = t_start_dataset - testdata.DT_ANALYSIS | ||
| t_end_config = t_end_dataset + testdata.DT_ANALYSIS | ||
|
|
||
| config = dict( | ||
| schema_version=testdata.SCHEMA_VERSION, | ||
| dataset_version="v0.1.0", | ||
| output=dict( | ||
| variables=dict( | ||
| static=["grid_index", "feature"], | ||
| state=["time", "grid_index", "feature"], | ||
| forcing=["time", "grid_index", "feature"], | ||
| ), | ||
| coord_ranges=dict( | ||
| time=dict( | ||
| start=t_start_config.isoformat(), | ||
| end=t_end_config.isoformat(), | ||
| step=isodate.duration_isoformat(time_stepsize), | ||
| ) | ||
| ), | ||
| ), | ||
| inputs=dict( | ||
| danra_surface=dict( | ||
| path=datasets["surface_analysis"], | ||
| dims=["analysis_time", "x", "y"], | ||
| variables=testdata.DEFAULT_SURFACE_ANALYSIS_VARS, | ||
| dim_mapping=dict( | ||
| time=dict( | ||
| method="rename", | ||
| dim="analysis_time", | ||
| ), | ||
| grid_index=dict( | ||
| method="stack", | ||
| dims=["x", "y"], | ||
| ), | ||
| feature=dict( | ||
| method="stack_variables_by_var_name", | ||
| name_format="{var_name}", | ||
| ), | ||
| ), | ||
| target_output_variable="forcing", | ||
| ), | ||
| ), | ||
| ) | ||
|
|
||
| # write yaml config to file | ||
| fn_config = "config.yaml" | ||
| fp_config = Path(tmpdir.name) / fn_config | ||
| with open(fp_config, "w") as f: | ||
| yaml.dump(config, f) | ||
|
|
||
| # run the main function | ||
| if source_data_contains_time_range and time_stepsize == testdata.DT_ANALYSIS: | ||
| mdp.create_dataset_zarr(fp_config=fp_config) | ||
| else: | ||
| print( | ||
| f"Expecting ValueError for source_data_contains_time_range={source_data_contains_time_range} and time_stepsize={time_stepsize}" | ||
| ) | ||
| with pytest.raises(ValueError): | ||
| mdp.create_dataset_zarr(fp_config=fp_config) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("use_common_feature_var_name", [True, False]) | ||
| def test_feature_collision(use_common_feature_var_name): | ||
| """ | ||
|
|
@@ -360,7 +275,6 @@ def test_config_revision_examples(fp_example): | |
| """ | ||
| tmpdir = tempfile.TemporaryDirectory() | ||
|
|
||
| # copy example to tempdir | ||
| fp_config_copy = Path(tmpdir.name) / fp_example.name | ||
| shutil.copy(fp_example, fp_config_copy) | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.