Fix selection #70

base: main
Changes from 29 commits
@@ -1,32 +1,19 @@

```python
import datetime
import warnings

import numpy as np
import pandas as pd

from ..config import Range


def to_timestamp(s):
    if isinstance(s, str):
        return pd.Timestamp(s)
    return s


def _normalize_slice_startstop(s):
    if isinstance(s, pd.Timestamp):
        return s
    elif isinstance(s, str):
        try:
            return pd.Timestamp(s)
        except ValueError:
            return s
    else:
        return s


def _normalize_slice_step(s):
    if isinstance(s, pd.Timedelta):
        return s
    elif isinstance(s, str):
        try:
            return pd.to_timedelta(s)
        except ValueError:
            return s
    else:
        return s


def to_timedelta(s):
    if isinstance(s, str):
        return np.timedelta64(pd.to_timedelta(s))
    return s


def select_by_kwargs(ds, **coord_ranges):
```
@@ -56,64 +43,44 @@ def select_by_kwargs(ds, **coord_ranges):

```python
    """

    for coord, selection in coord_ranges.items():
        if coord not in ds.coords:
            raise ValueError(f"Coordinate {coord} not found in dataset")
        if isinstance(selection, Range):
            if selection.start is None and selection.end is None:
                raise ValueError(
                    f"Selection for coordinate {coord} must have either 'start' and 'end' given"
                )
            sel_start = _normalize_slice_startstop(selection.start)
            sel_end = _normalize_slice_startstop(selection.end)
            sel_step = _normalize_slice_step(selection.step)

            assert sel_start != sel_end, "Start and end cannot be the same"

            # we don't select with the step size for now, but simply check (below) that
            # the step size in the data is the same as the requested step size
            ds = ds.sel({coord: slice(sel_start, sel_end)})

            if coord == "time":
                check_point_in_dataset(coord, sel_start, ds)
                check_point_in_dataset(coord, sel_end, ds)
                if sel_step is not None:
                    check_step(sel_step, coord, ds)

            assert (
                len(ds[coord]) > 0
            ), f"You have selected an empty range {sel_start}:{sel_end} for coordinate {coord}"

        elif isinstance(selection, list):
            ds = ds.sel({coord: selection})
        else:
            raise NotImplementedError(
                f"Selection for coordinate {coord} must be a list or a dict"
            )

            sel_start = selection.start
            sel_end = selection.end
            sel_step = selection.step

            if coord == "time":
                sel_start = to_timestamp(selection.start)
                sel_end = to_timestamp(selection.end)
                sel_step = get_time_step(sel_step, ds)
```
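Since `get_time_step` returns an integer, the later `slice(sel_start, sel_end, sel_step)` is a label-based slice with a positional step, which xarray supports; a small standalone sketch (the dataset and variable name are illustrative, not from the PR):

```python
import numpy as np
import pandas as pd
import xarray as xr

# Hypothetical 3-hourly dataset
times = pd.date_range("1990-09-03", periods=8, freq="3h")
ds = xr.Dataset({"t2m": ("time", np.arange(8.0))}, coords={"time": times})

# Label-based start/end combined with an integer (positional) step,
# mirroring slice(sel_start, sel_end, sel_step) in the new code;
# keeps every second 3-hourly timestep
sub = ds.sel(time=slice("1990-09-03T00:00", "1990-09-03T21:00", 2))
```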
Contributor, commenting on lines +50 to +53: I agree with @leifdenby that it would be good to only accept datetimes, so this rather boilerplate code is not necessary.
```python
            assert sel_start != sel_end, "Start and end cannot be the same"
```
```python
            check_selection(ds, coord, sel_start, sel_end)
            ds = ds.sel({coord: slice(sel_start, sel_end, sel_step)})

            assert (
                len(ds[coord]) > 0
            ), f"You have selected an empty range {sel_start}:{sel_end} for coordinate {coord}"

    return ds


def check_point_in_dataset(coord, point, ds):
    """
    check that the requested point is in the data.
    """
    if point is not None and point not in ds[coord].values:
        raise ValueError(
            f"Provided value for coordinate {coord} ({point}) is not in the data."
        )


def get_time_step(sel_step, ds):
    if sel_step is None:
        return None

    dataset_timedelta = ds.time[1] - ds.time[0]
    sel_timedelta = to_timedelta(sel_step)
    step = sel_timedelta / dataset_timedelta
    if step % 1 != 0:
        raise ValueError(
            f"The chosen stepsize {sel_step} is not multiple of the stepsize in the dataset {dataset_timedelta}"
        )

    return int(step)


def check_step(sel_step, coord, ds):
    """
    check that the step requested is exactly what the data has
    """
    all_steps = ds[coord].diff(dim=coord).values
    first_step = all_steps[0].astype("timedelta64[s]").astype(datetime.timedelta)

    if not all(all_steps[0] == all_steps):
        raise ValueError(
            f"Step size for coordinate {coord} is not constant: {all_steps}"
        )
    if sel_step != first_step:
        raise ValueError(
            f"Step size for coordinate {coord} is not the same as requested: {first_step} != {sel_step}"
        )


def check_selection(ds, coord, sel_start, sel_end):
    if ds[coord].values.min() < sel_start or ds[coord].values.max() > sel_end:
        warnings.warn(
            f"\nChosen slice exceeds the range of {coord} in the dataset.\n"
            f" Dataset span: [ {ds[coord].values.min()} : {ds[coord].values.max()} ]\n"
            f" Chosen slice: [ {sel_start} : {sel_end} ]\n"
        )
```
Contributor, commenting on lines 81 to +85: This check assumes that the coordinates are sorted in ascending order; however, this is not always the case. Example:

```python
import xarray as xr
import numpy as np

# Create geographic coordinates
latitude = np.array([10, 5, 0, -5, -10])  # Latitude in descending order
longitude = np.array([30, 35, 40])  # Longitude in ascending order

# Create some sample data
data = np.random.rand(len(latitude), len(longitude))  # 2D data over latitude and longitude

# Create an xarray Dataset with geographic coordinates
ds = xr.Dataset(
    {
        "temperature": (("latitude", "longitude"), data),
        "precipitation": (("latitude", "longitude"), data * 0.1),  # Example precipitation data
    },
    coords={
        "latitude": latitude,
        "longitude": longitude,
    },
)

ds.sel({"latitude": slice(8, -5)})  # this slice is within bounds but would raise the above warning
```

It should also be … Nevertheless, I would not do this test at all. The user should know what a valid range is.
@@ -0,0 +1,61 @@

```yaml
schema_version: v0.6.0
dataset_version: v0.1.0

output:
  variables:
    state: [time, grid_index, state_feature]
  chunking:
    time: 1
  splitting:
    dim: time
    splits:
      train:
        start: 1990-09-03T00:00
        end: 1990-09-06T00:00
        compute_statistics:
          ops: [mean, std, diff_mean, diff_std]
          dims: [grid_index, time]
      val:
        start: 1990-09-06T00:00
        end: 1990-09-07T00:00
      test:
        start: 1990-09-07T00:00
        end: 1990-09-09T00:00

inputs:
  danra_height_levels:
    path: https://mllam-test-data.s3.eu-north-1.amazonaws.com/height_levels.zarr
    dims: [time, x, y, altitude]
    variables:
      u:
        altitude:
          values: [100,]
          units: m
      v:
        altitude:
          values: [100, ]
          units: m
    dim_mapping:
      time:
        method: rename
        dim: time
      state_feature:
        method: stack_variables_by_var_name
        dims: [altitude]
        name_format: "{var_name}{altitude}m"
      grid_index:
        method: stack
        dims: [x, y]
    coord_ranges:
      x:
        start: -50000
        end: -40000
      y:
        start: -50000
        end: -40000
      time:
        start: "1990-09-03T00:00"
        end: "1990-09-09T00:00"
        step: "PT3H"

    target_output_variable: state
```
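The `step: "PT3H"` above only makes sense if it is an integer multiple of the dataset's time spacing, which is what the new `get_time_step` enforces. The arithmetic can be sketched in isolation (the `step_multiple` helper and the 3-hourly time axis are illustrative, not part of the PR):

```python
import numpy as np
import pandas as pd

# Hypothetical 3-hourly time axis matching the config's time range
times = pd.date_range("1990-09-03T00:00", "1990-09-09T00:00", freq="3h")
dataset_timedelta = np.timedelta64(times[1] - times[0])

def step_multiple(sel_step: str) -> int:
    """Return the requested ISO 8601 step as an integer multiple of the data's spacing."""
    ratio = np.timedelta64(pd.to_timedelta(sel_step)) / dataset_timedelta
    if ratio % 1 != 0:
        raise ValueError(
            f"step {sel_step} is not a multiple of the dataset spacing {dataset_timedelta}"
        )
    return int(ratio)
```

With 3-hourly data, "PT3H" yields a step of 1 and "PT6H" a step of 2, while e.g. "PT4H" raises.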
@@ -112,91 +112,6 @@ def test_merging_static_and_surface_analysis():

```python
    mdp.create_dataset_zarr(fp_config=fp_config)


@pytest.mark.parametrize("source_data_contains_time_range", [True, False])
@pytest.mark.parametrize(
    "time_stepsize",
    [testdata.DT_ANALYSIS, testdata.DT_ANALYSIS * 2, testdata.DT_ANALYSIS / 2],
)
def test_time_selection(source_data_contains_time_range, time_stepsize):
```
Member: how come you are getting rid of this test? Don't you like it 😆
```python
    """
    Check that time selection works as expected, so that when the source
    data doesn't contain the time range specified in the config an exception
    is raised, and otherwise that the correct timesteps are in the output
    """

    tmpdir = tempfile.TemporaryDirectory()
    datasets = testdata.create_data_collection(
        data_kinds=["surface_analysis", "static"], fp_root=tmpdir.name
    )

    t_start_dataset = testdata.T_START
    t_end_dataset = t_start_dataset + (testdata.NT_ANALYSIS - 1) * testdata.DT_ANALYSIS

    if source_data_contains_time_range:
        t_start_config = t_start_dataset
        t_end_config = t_end_dataset
    else:
        t_start_config = t_start_dataset - testdata.DT_ANALYSIS
        t_end_config = t_end_dataset + testdata.DT_ANALYSIS

    config = dict(
        schema_version=testdata.SCHEMA_VERSION,
        dataset_version="v0.1.0",
        output=dict(
            variables=dict(
                static=["grid_index", "feature"],
                state=["time", "grid_index", "feature"],
                forcing=["time", "grid_index", "feature"],
            ),
            coord_ranges=dict(
                time=dict(
                    start=t_start_config.isoformat(),
                    end=t_end_config.isoformat(),
                    step=isodate.duration_isoformat(time_stepsize),
                )
            ),
        ),
        inputs=dict(
            danra_surface=dict(
                path=datasets["surface_analysis"],
                dims=["analysis_time", "x", "y"],
                variables=testdata.DEFAULT_SURFACE_ANALYSIS_VARS,
                dim_mapping=dict(
                    time=dict(
                        method="rename",
                        dim="analysis_time",
                    ),
                    grid_index=dict(
                        method="stack",
                        dims=["x", "y"],
                    ),
                    feature=dict(
                        method="stack_variables_by_var_name",
                        name_format="{var_name}",
                    ),
                ),
                target_output_variable="forcing",
            ),
        ),
    )

    # write yaml config to file
    fn_config = "config.yaml"
    fp_config = Path(tmpdir.name) / fn_config
    with open(fp_config, "w") as f:
        yaml.dump(config, f)

    # run the main function
    if source_data_contains_time_range and time_stepsize == testdata.DT_ANALYSIS:
        mdp.create_dataset_zarr(fp_config=fp_config)
    else:
        print(
            f"Expecting ValueError for source_data_contains_time_range={source_data_contains_time_range} and time_stepsize={time_stepsize}"
        )
        with pytest.raises(ValueError):
            mdp.create_dataset_zarr(fp_config=fp_config)


@pytest.mark.parametrize("use_common_feature_var_name", [True, False])
def test_feature_collision(use_common_feature_var_name):
    """
```
@@ -360,7 +275,6 @@ def test_config_revision_examples(fp_example):

```python
    """
    tmpdir = tempfile.TemporaryDirectory()

    # copy example to tempdir
    fp_config_copy = Path(tmpdir.name) / fp_example.name
    shutil.copy(fp_example, fp_config_copy)
```
Comment: ah, I'm an idiot. I am not sure we need to support `str` here since I added that for iso8601 formatted strings. I thought that yaml doesn't natively support datetime serialisation, but I misremembered, that is json! So, if people write iso8601 formatted strings then I think yaml should always turn those into `datetime.datetime` objects if we define that as the type here. That also means we don't have to handle turning strings into datetime/timedelta objects in the code, nice!

Comment: which I think should mean we should remove `str` here. Do you agree @matschreiner and @observingClouds?

Comment: I imagine some folks might still write times as strings, but not allowing it in the first place would probably be more robust, so I'd be happy to drop `str` support.
Comment: Hmm, I'm confused. Our current config schema supports the following: … What I am suggesting is that if we replace the `str` type with `datetime` in the config dataclasses, then the dataclass serialisation (from `dataclass-wizard`) will handle turning the `start` and `end` fields into `datetime` objects, rather than us having to do it. In terms of what is in the config yaml-files, they would remain unchanged though? So people can keep defining the start/end times as they already are. What was previously interpreted as a string will now simply be interpreted as a `datetime` serialised as an iso8601 string, no?
Comment: I was just experimenting together with @matschreiner, and `1990-09-03T00:00` is not treated as a valid ISO format: it is not converted to datetime and remains a string; `1990-09-03T00:00:00`, however, is. There seems to be an issue with a roundtrip though, as `mdp.Config.to_yaml()`/`to_dict`/`to_json` all serialize datetimes back to strings... this might be an upstream issue in dataclass-wizard.
mdp.Config.to_yaml()/to_dict/to_jsonall serialize datetimes back to strings...this might be an upstream issue in the dataclass-wizardThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm that is strange. Omitting the seconds should be valid, https://en.wikipedia.org/wiki/ISO_8601#Times
Maybe this is an upstream bug, or YAML requires the seconds be included...
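This can be checked directly with PyYAML: its (YAML 1.1) timestamp resolver only matches scalars that include the seconds field, so the shorter form stays a string. A quick sketch, assuming PyYAML is the loader in use:

```python
import datetime
import yaml  # PyYAML

# The YAML 1.1 timestamp pattern requires HH:MM:SS, so only the
# first form is resolved to a datetime; the second stays a string
with_seconds = yaml.safe_load("start: 1990-09-03T00:00:00")
without_seconds = yaml.safe_load("start: 1990-09-03T00:00")
```

So the behaviour observed above looks like the YAML spec rather than a dataclass-wizard bug.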
Comment: In the `dataclass-wizard` docs, https://dataclass-wizard.readthedocs.io/en/latest/overview.html#supported-types, this seems ok, no? It does mean the round trip wouldn't result in exactly the same yaml, because the seconds will be added by the `isoformat()` call, https://docs.python.org/3/library/datetime.html#datetime.datetime.isoformat
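The round-trip difference can be reproduced with plain `datetime`, independent of dataclass-wizard: parsing without seconds works, but `isoformat()` always emits them:

```python
from datetime import datetime

# fromisoformat accepts HH:MM, but isoformat() always writes HH:MM:SS
t = datetime.fromisoformat("1990-09-03T00:00")
roundtripped = t.isoformat()  # seconds are added on serialisation
```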