Skip to content

Commit 137ef79

Browse files
committed
Merge remote-tracking branch 'origin/jacob/netcdf' into jacob/netcdf
2 parents 2458239 + c47b7c8 commit 137ef79

File tree

6 files changed

+283
-25
lines changed

6 files changed

+283
-25
lines changed

.bumpversion.cfg

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[bumpversion]
22
commit = True
33
tag = True
4-
current_version = 2.0.9
4+
current_version = 2.0.11
55
message = Bump version: {current_version} → {new_version} [skip ci]
66

77
[bumpversion:file:setup.py]

.github/workflows/workflows.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@ jobs:
1616
sudo_apt_install: "libgeos++-dev libproj-dev proj-data proj-bin"
1717
# brew_install: "proj geos librttopo"
1818
os_list: '["ubuntu-latest"]'
19+
python-version: "['3.9','3.10','3.11']"

ocf_datapipes/transform/xarray/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@
3434
from .get_contiguous_time_periods import (
3535
GetContiguousT0TimePeriodsIterDataPipe as GetContiguousT0TimePeriods,
3636
)
37+
from .get_contiguous_time_periods import (
38+
GetContiguousT0TimePeriodsNWPIterDataPipe as GetContiguousT0TimePeriodsNWP,
39+
)
3740
from .gsp.create_gsp_image import CreateGSPImageIterDataPipe as CreateGSPImage
3841
from .gsp.ensure_n_gsp_per_example import (
3942
EnsureNGSPSPerExampleIterDataPipe as EnsureNGSPSPerExampleIter,

ocf_datapipes/transform/xarray/get_contiguous_time_periods.py

+112-13
Original file line numberDiff line numberDiff line change
@@ -61,21 +61,48 @@ def __iter__(self) -> pd.DataFrame:
6161
yield contiguous_time_periods
6262

6363

64-
def get_contiguous_t0_time_periods(
65-
contiguous_time_periods: pd.DataFrame, history_duration: timedelta, forecast_duration: timedelta
66-
) -> pd.DataFrame:
67-
"""Get all time periods which contain valid t0 datetimes.
64+
@functional_datapipe("get_contiguous_time_periods_nwp")
65+
class GetContiguousT0TimePeriodsNWPIterDataPipe(IterDataPipe):
66+
"""Get contiguous NWP time periods for training"""
6867

69-
`t0` is the datetime of the most recent observation.
68+
def __init__(
69+
self,
70+
source_datapipe: IterDataPipe,
71+
history_duration: timedelta,
72+
max_staleness: timedelta = timedelta(minutes=0),
73+
max_dropout: timedelta = timedelta(minutes=0),
74+
time_dim: str = "init_time_utc",
75+
):
76+
"""
77+
Get contiguous time periods for use in determining t0 times for training
7078
71-
Returns:
72-
pd.DataFrame where each row represents a single time period. The pd.DataFrame
73-
has two columns: `start_dt` and `end_dt` (where 'dt' is short for 'datetime').
74-
"""
75-
contiguous_time_periods["start_dt"] += history_duration
76-
contiguous_time_periods["end_dt"] -= forecast_duration
77-
assert (contiguous_time_periods["start_dt"] < contiguous_time_periods["end_dt"]).all()
78-
return contiguous_time_periods
79+
Args:
80+
source_datapipe: Datapipe emitting a Xarray dataset
81+
history_duration: Length of the historical slice used for a sample
82+
max_staleness: Up to how long after an NWP forecast init_time are we willing to use the
83+
forecast. Each init time will only be used up to this t0 time regardless of the
84+
forecast valid time.
85+
max_dropout: What is the maximum amount of dropout that will be used. This must be <=
86+
max_staleness.
87+
time_dim: time dimensions for which to find the contiguous time periods
88+
"""
89+
self.source_datapipe = source_datapipe
90+
self.history_duration = history_duration
91+
self.max_staleness = max_staleness
92+
self.max_dropout = max_dropout
93+
self.time_dim = time_dim
94+
95+
def __iter__(self) -> pd.DataFrame:
96+
"""Calculate contiguous time periods and return a dataframe containing them"""
97+
for xr_data in self.source_datapipe:
98+
logger.debug("Getting contiguous NWP t0 time periods")
99+
contiguous_time_periods = get_contiguous_t0_periods_nwp(
100+
datetimes=pd.DatetimeIndex(xr_data[self.time_dim]),
101+
history_duration=self.history_duration,
102+
max_staleness=self.max_staleness,
103+
max_dropout=self.max_dropout,
104+
)
105+
yield contiguous_time_periods
79106

80107

81108
def get_contiguous_time_periods(
@@ -132,3 +159,75 @@ def get_contiguous_time_periods(
132159
)
133160

134161
return pd.DataFrame(periods)
162+
163+
164+
def get_contiguous_t0_time_periods(
165+
contiguous_time_periods: pd.DataFrame, history_duration: timedelta, forecast_duration: timedelta
166+
) -> pd.DataFrame:
167+
"""Get all time periods which contain valid t0 datetimes.
168+
169+
`t0` is the datetime of the most recent observation.
170+
171+
Returns:
172+
pd.DataFrame where each row represents a single time period. The pd.DataFrame
173+
has two columns: `start_dt` and `end_dt` (where 'dt' is short for 'datetime').
174+
"""
175+
contiguous_time_periods["start_dt"] += history_duration
176+
contiguous_time_periods["end_dt"] -= forecast_duration
177+
assert (contiguous_time_periods["start_dt"] < contiguous_time_periods["end_dt"]).all()
178+
return contiguous_time_periods
179+
180+
181+
def get_contiguous_t0_periods_nwp(
182+
datetimes: pd.DatetimeIndex,
183+
history_duration: timedelta,
184+
max_staleness: timedelta,
185+
max_dropout: timedelta = timedelta(0),
186+
) -> pd.DataFrame:
187+
"""Get all time periods from the NWP init times which are valid as t0 datetimes.
188+
189+
Args:
190+
datetimes: Sorted pd.DatetimeIndex
191+
history_duration: Length of the historical slice used for a sample
192+
max_staleness: Up to how long after an NWP forecast init_time are we willing to use the
193+
forecast. Each init time will only be used up to this t0 time regardless of the forecast
194+
valid time.
195+
max_dropout: What is the maximum amount of dropout that will be used. This must be <=
196+
max_staleness.
197+
198+
Returns:
199+
pd.DataFrame where each row represents a single time period. The pd.DataFrame
200+
has two columns: `start_dt` and `end_dt` (where 'dt' is short for 'datetime').
201+
"""
202+
# Sanity checks.
203+
assert len(datetimes) > 0
204+
assert datetimes.is_monotonic_increasing
205+
assert datetimes.is_unique
206+
assert history_duration >= timedelta(0)
207+
assert max_staleness >= timedelta(0)
208+
assert max_dropout <= max_staleness
209+
210+
hist_drop_buffer = max(history_duration, max_dropout)
211+
212+
# Store contiguous periods
213+
contiguous_periods = []
214+
215+
# Start first period allowing for history slice and max dropout
216+
start_this_period = datetimes[0] + hist_drop_buffer
217+
218+
# The first forecast is valid up to the max staleness
219+
end_this_period = datetimes[0] + max_staleness
220+
221+
for dt_init in datetimes[1:]:
222+
# If the previous init time becomes stale before the next init becomes valid whilst also
223+
# considering dropout and the need for a historic period - then the contiguous period breaks
224+
if end_this_period < dt_init + hist_drop_buffer:
225+
contiguous_periods += [[start_this_period, end_this_period]]
226+
227+
# And start a new period
228+
start_this_period = dt_init + hist_drop_buffer
229+
end_this_period = dt_init + max_staleness
230+
231+
contiguous_periods += [[start_this_period, end_this_period]]
232+
233+
return pd.DataFrame(contiguous_periods, columns=["start_dt", "end_dt"])

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
setup(
1212
name="ocf_datapipes",
13-
version="2.0.9",
13+
version="2.0.11",
1414
license="MIT",
1515
description="Pytorch Datapipes built for use in Open Climate Fix's forecasting work",
1616
author="Jacob Bieker, Jack Kelly, Peter Dudfield, James Fulton",
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,172 @@
11
from datetime import timedelta
22

3-
from ocf_datapipes.select import DropGSP, LocationPicker
4-
from ocf_datapipes.transform.xarray import GetContiguousT0TimePeriods
3+
import numpy as np
4+
import pandas as pd
5+
6+
from torchdata.datapipes.iter import IterableWrapper
7+
from ocf_datapipes.transform.xarray import GetContiguousT0TimePeriods, GetContiguousT0TimePeriodsNWP
8+
9+
10+
def _remove_indexes(x, inds):
11+
xs = []
12+
i_last = -1
13+
for i in np.sort(inds):
14+
xs += [x[i_last + 1 : i]]
15+
i_last = i
16+
xs += [x[i_last + 1 :]]
17+
return pd.to_datetime(np.concatenate(xs))
518

619

720
def test_get_contiguous_time_periods(nwp_datapipe):
8-
nwp_datapipe = GetContiguousT0TimePeriods(
9-
nwp_datapipe,
10-
sample_period_duration=timedelta(hours=3),
11-
history_duration=timedelta(minutes=60),
12-
forecast_duration=timedelta(minutes=180),
13-
time_dim="init_time_utc",
21+
# Create 5-minutely data timestamps
22+
freq = timedelta(minutes=5)
23+
history_duration = timedelta(minutes=60)
24+
forecast_duration = timedelta(minutes=15)
25+
26+
datetimes = _remove_indexes(
27+
pd.date_range("2023-01-01 12:00", "2023-01-01 17:00", freq=freq),
28+
[5, 30],
29+
)
30+
31+
# Create initial datapipe
32+
time_datapipe = IterableWrapper([pd.DataFrame(datetimes, columns=["time_utc"]).to_xarray()])
33+
34+
history_duration = timedelta(minutes=60)
35+
36+
contig_t0_datapipe = GetContiguousT0TimePeriods(
37+
time_datapipe,
38+
sample_period_duration=freq,
39+
history_duration=history_duration,
40+
forecast_duration=forecast_duration,
41+
time_dim="time_utc",
42+
)
43+
44+
periods = next(iter(contig_t0_datapipe))
45+
46+
expected_results = pd.DataFrame(
47+
{
48+
"start_dt": pd.to_datetime(
49+
[
50+
"2023-01-01 13:30:00",
51+
"2023-01-01 15:35:00",
52+
]
53+
),
54+
"end_dt": pd.to_datetime(
55+
[
56+
"2023-01-01 14:10:00",
57+
"2023-01-01 16:45:00",
58+
]
59+
),
60+
},
1461
)
1562

16-
batch = next(iter(nwp_datapipe))
17-
print(batch)
63+
assert periods.equals(expected_results)
64+
65+
66+
def test_get_contiguous_time_periods_nwp():
67+
# These are the expected results of the test
68+
expected_results = [
69+
pd.DataFrame(
70+
{
71+
"start_dt": pd.to_datetime(["2023-01-01 03:00:00", "2023-01-02 03:00:00"]),
72+
"end_dt": pd.to_datetime(["2023-01-01 21:00:00", "2023-01-03 06:00:00"]),
73+
},
74+
),
75+
pd.DataFrame(
76+
{
77+
"start_dt": pd.to_datetime(
78+
[
79+
"2023-01-01 05:00:00",
80+
"2023-01-02 05:00:00",
81+
"2023-01-02 14:00:00",
82+
]
83+
),
84+
"end_dt": pd.to_datetime(
85+
[
86+
"2023-01-01 21:00:00",
87+
"2023-01-02 12:00:00",
88+
"2023-01-03 06:00:00",
89+
]
90+
),
91+
},
92+
),
93+
pd.DataFrame(
94+
{
95+
"start_dt": pd.to_datetime(
96+
[
97+
"2023-01-01 05:00:00",
98+
"2023-01-01 11:00:00",
99+
"2023-01-02 05:00:00",
100+
"2023-01-02 14:00:00",
101+
]
102+
),
103+
"end_dt": pd.to_datetime(
104+
[
105+
"2023-01-01 09:00:00",
106+
"2023-01-01 18:00:00",
107+
"2023-01-02 09:00:00",
108+
"2023-01-03 03:00:00",
109+
]
110+
),
111+
},
112+
),
113+
pd.DataFrame(
114+
{
115+
"start_dt": pd.to_datetime(
116+
[
117+
"2023-01-01 05:00:00",
118+
"2023-01-01 11:00:00",
119+
"2023-01-01 14:00:00",
120+
"2023-01-02 05:00:00",
121+
"2023-01-02 14:00:00",
122+
"2023-01-02 17:00:00",
123+
"2023-01-02 20:00:00",
124+
"2023-01-02 23:00:00",
125+
]
126+
),
127+
"end_dt": pd.to_datetime(
128+
[
129+
"2023-01-01 06:00:00",
130+
"2023-01-01 12:00:00",
131+
"2023-01-01 15:00:00",
132+
"2023-01-02 06:00:00",
133+
"2023-01-02 15:00:00",
134+
"2023-01-02 18:00:00",
135+
"2023-01-02 21:00:00",
136+
"2023-01-03 00:00:00",
137+
]
138+
),
139+
},
140+
),
141+
]
142+
143+
# Create 3-hourly init times with a few time stamps missing
144+
freq = timedelta(minutes=180)
145+
146+
datetimes = _remove_indexes(
147+
pd.date_range("2023-01-01 03:00", "2023-01-02 21:00", freq=freq),
148+
[1, 4, 5, 6, 7, 9, 10],
149+
)
150+
151+
# Choose some history durations and max stalenesses
152+
history_durations_hr = [0, 2, 2, 2]
153+
max_stalenesses_hr = [9, 9, 6, 3]
154+
155+
for i in range(len(expected_results)):
156+
history_duration = timedelta(hours=history_durations_hr[i])
157+
max_staleness = timedelta(hours=max_stalenesses_hr[i])
158+
159+
# Create initial datapipe
160+
time_datapipe = IterableWrapper(
161+
[pd.DataFrame(datetimes, columns=["init_time_utc"]).to_xarray()]
162+
)
163+
164+
time_periods = time_datapipe.get_contiguous_time_periods_nwp(
165+
history_duration=history_duration,
166+
max_staleness=max_staleness,
167+
time_dim="init_time_utc",
168+
)
169+
170+
# Check if results are as expected
171+
results = next(iter(time_periods))
172+
assert results.equals(expected_results[i])

0 commit comments

Comments
 (0)