Skip to content

Commit 62ea906

Browse files
committed
Add missing function
1 parent 5a8e15c commit 62ea906

File tree

1 file changed

+93
-0
lines changed

1 file changed

+93
-0
lines changed

ocf_datapipes/training/common.py

+93
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,99 @@ def open_and_return_datapipes(
152152
return used_datapipes
153153

154154

155+
def get_and_return_overlapping_time_periods_and_t0(used_datapipes: dict, key_for_t0: str = "gsp"):
156+
"""
157+
Takes datapipes and obtains the overlapping time periods + t0 time datapipes
158+
159+
Args:
160+
used_datapipes: Dictionary of datapipes to compute the time intersection of
161+
key_for_t0: Key to use for the t0 datapipe
162+
163+
Returns:
164+
Dictionary of datapipes with the proper time slices selected
165+
"""
166+
datapipes_for_time_periods = [] # Using later to compute intersections
167+
datapipes_to_return = {} # Returned along with original ones
168+
t0_datapipe = None
169+
configuration = used_datapipes.pop("config")
170+
for key, datapipe in used_datapipes.items():
171+
if "topo" in key:
172+
continue
173+
if key_for_t0 in key:
174+
forked_datapipes = datapipe.fork(3, buffer_size=100)
175+
t0_datapipe = forked_datapipes[2]
176+
else:
177+
forked_datapipes = datapipe.fork(2, buffer_size=100)
178+
datapipes_to_return[key] = forked_datapipes[0]
179+
if "nwp" == key:
180+
time_periods_datapipe = forked_datapipes[1].get_contiguous_time_periods(
181+
sample_period_duration=timedelta(hours=3), # Init times are 3 hours apart
182+
history_duration=timedelta(minutes=configuration.input_data.nwp.history_minutes),
183+
forecast_duration=timedelta(minutes=configuration.input_data.nwp.forecast_minutes),
184+
time_dim="init_time_utc",
185+
)
186+
datapipes_for_time_periods.append(time_periods_datapipe)
187+
188+
if "sat" == key:
189+
time_periods_datapipe = forked_datapipes[1].get_contiguous_time_periods(
190+
sample_period_duration=timedelta(minutes=5),
191+
history_duration=timedelta(
192+
minutes=configuration.input_data.satellite.history_minutes
193+
),
194+
forecast_duration=timedelta(minutes=0),
195+
)
196+
datapipes_for_time_periods.append(time_periods_datapipe)
197+
198+
if "hrv" == key:
199+
time_periods_datapipe = forked_datapipes[1].get_contiguous_time_periods(
200+
sample_period_duration=timedelta(minutes=5),
201+
history_duration=timedelta(
202+
minutes=configuration.input_data.hrvsatellite.history_minutes
203+
),
204+
forecast_duration=timedelta(minutes=0),
205+
)
206+
datapipes_for_time_periods.append(time_periods_datapipe)
207+
208+
if "pv" == key:
209+
time_periods_datapipe = forked_datapipes[1].get_contiguous_time_periods(
210+
sample_period_duration=timedelta(minutes=5),
211+
history_duration=timedelta(minutes=configuration.input_data.pv.history_minutes),
212+
forecast_duration=timedelta(minutes=configuration.input_data.pv.forecast_minutes),
213+
)
214+
datapipes_for_time_periods.append(time_periods_datapipe)
215+
if "gsp" == key:
216+
time_periods_datapipe = forked_datapipes[1].get_contiguous_time_periods(
217+
sample_period_duration=timedelta(minutes=30),
218+
history_duration=timedelta(minutes=configuration.input_data.gsp.history_minutes),
219+
forecast_duration=timedelta(minutes=configuration.input_data.gsp.forecast_minutes),
220+
)
221+
datapipes_for_time_periods.append(time_periods_datapipe)
222+
223+
# Now have the forked ones
224+
# find joint overlapping timer periods
225+
logger.debug("Getting joint time periods")
226+
overlapping_datapipe = datapipes_for_time_periods[0].select_overlapping_time_slice(
227+
secondary_datapipes=datapipes_for_time_periods[1:],
228+
)
229+
230+
# select time periods
231+
t0_datapipe = t0_datapipe.select_time_periods(time_periods=overlapping_datapipe)
232+
233+
num_t0_datapipes = len(datapipes_to_return.keys()) # One for each input
234+
t0_datapipes = t0_datapipe.select_t0_time(return_all_times=False).fork(
235+
num_t0_datapipes, buffer_size=100
236+
)
237+
238+
for i, key in enumerate(list(datapipes_to_return.keys())):
239+
datapipes_to_return[key + "_t0"] = t0_datapipes[i]
240+
241+
# Re-add config for later
242+
datapipes_to_return["config"] = configuration
243+
if "topo" in used_datapipes.keys():
244+
datapipes_to_return["topo"] = used_datapipes["topo"]
245+
return datapipes_to_return
246+
247+
155248
def normalize_gsp(x):
156249
"""Normalize the GSP data
157250

0 commit comments

Comments
 (0)