
Commit 77be003

Merge pull request #238 from openclimatefix/jacob/netcdf
Add converting batches to NetCDF and saving them out #minor
2 parents 2631b48 + 62ea906 commit 77be003

File tree

8 files changed: +1111 -576 lines


ocf_datapipes/training/common.py: +572 -1 (large diff not rendered)

ocf_datapipes/training/pvnet.py: +13 -574 (large diff not rendered)

ocf_datapipes/training/windnet.py: +385
@@ -0,0 +1,385 @@
"""Create the training/validation datapipe for training the WindNet model"""
import logging
from datetime import datetime, timedelta
from typing import List, Optional, Tuple, Union

import xarray as xr
from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.iter import IterableWrapper, IterDataPipe

from ocf_datapipes.batch import MergeNumpyModalities
from ocf_datapipes.config.model import Configuration
from ocf_datapipes.load import (
    OpenConfiguration,
)
from ocf_datapipes.training.common import (
    AddZeroedNWPData,
    AddZeroedSatelliteData,
    _get_datapipes_dict,
    check_nans_in_satellite_data,
    concat_xr_time_utc,
    construct_loctime_pipelines,
    fill_nans_in_arrays,
    fill_nans_in_pv,
    normalize_gsp,
    normalize_pv,
    slice_datapipes_by_time,
)
from ocf_datapipes.utils.consts import (
    NEW_NWP_MEAN,
    NEW_NWP_STD,
    RSS_MEAN,
    RSS_STD,
)
from ocf_datapipes.utils.utils import combine_to_single_dataset, uncombine_from_single_dataset

xr.set_options(keep_attrs=True)
logger = logging.getLogger("windnet_datapipe")


def scale_wind_speed_to_power(x: Union[xr.DataArray, xr.Dataset]):
    """
    Scale wind speed to power to estimate the generation of wind power from ground sensors

    Roughly, double the speed in m/s, and convert with the power scale

    Args:
        x: xr.DataArray or xr.Dataset containing wind speed

    Returns:
        Wind speed rescaled to roughly MWh
    """
    # Convert knots to m/s
    x = x * 0.514444
    # Roughly double the speed to get power
    x = x * 2
    return x
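
# Illustrative worked example (hypothetical values, not part of the pipeline):
# a 10-knot sensor reading becomes 10 * 0.514444 = 5.14444 m/s, which is then
# doubled to roughly 10.29 on the power scale.
#     >>> scale_wind_speed_to_power(xr.DataArray([10.0])).values
#     array([10.28888])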
@functional_datapipe("dict_datasets")
60+
class DictDatasetIterDataPipe(IterDataPipe):
61+
"""Create a dictionary of xr.Datasets from a set of iterators"""
62+
63+
datapipes: Tuple[IterDataPipe]
64+
length: Optional[int]
65+
66+
def __init__(self, *datapipes: IterDataPipe, keys: List[str]):
67+
"""Init"""
68+
if not all(isinstance(dp, IterDataPipe) for dp in datapipes):
69+
raise TypeError(
70+
"All inputs are required to be `IterDataPipe` " "for `ZipIterDataPipe`."
71+
)
72+
super().__init__()
73+
self.keys = keys
74+
self.datapipes = datapipes # type: ignore[assignment]
75+
self.length = None
76+
assert len(self.keys) == len(self.datapipes), "Number of keys must match number of pipes"
77+
78+
def __iter__(self):
79+
"""Iter"""
80+
iterators = [iter(datapipe) for datapipe in self.datapipes]
81+
for data in zip(*iterators):
82+
# Yield a dictionary of the data, using the keys in self.keys
83+
yield {k: v for k, v in zip(self.keys, data)}
84+
85+
86+
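
# Illustrative usage sketch (hypothetical pipes and datasets): zip several
# datapipes into a single stream of keyed dictionaries.
#     >>> gsp_pipe = IterableWrapper([gsp_ds])  # each element an xr.Dataset
#     >>> nwp_pipe = IterableWrapper([nwp_ds])
#     >>> dict_pipe = DictDatasetIterDataPipe(gsp_pipe, nwp_pipe, keys=["gsp", "nwp"])
#     >>> next(iter(dict_pipe))
#     {'gsp': <xarray.Dataset ...>, 'nwp': <xarray.Dataset ...>}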
@functional_datapipe("load_dict_datasets")
87+
class LoadDictDatasetIterDataPipe(IterDataPipe):
88+
"""Load NetCDF files and split them back into individual xr.Datasets"""
89+
90+
filenames: List[str]
91+
keys: List[str]
92+
93+
def __init__(self, filenames: List[str], keys: List[str]):
94+
"""
95+
Load NetCDF files and split them back into individual xr.Datasets
96+
97+
Args:
98+
filenames: List of filesnames to load
99+
keys: List of keys from each file to use, each key should be a
100+
dataarray in the xr.Dataset
101+
"""
102+
super().__init__()
103+
self.keys = keys
104+
self.filenames = filenames
105+
106+
def __iter__(self):
107+
"""Iterate through each filename, loading it, uncombining it, and then yielding it"""
108+
while True:
109+
for filename in self.filenames:
110+
dataset = xr.open_dataset(filename)
111+
datasets = uncombine_from_single_dataset(dataset)
112+
# Yield a dictionary of the data, using the keys in self.keys
113+
dataset_dict = {}
114+
for k in self.keys:
115+
dataset_dict[k] = datasets[k]
116+
yield dataset_dict
117+
118+
119+
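
# Illustrative usage sketch (hypothetical file paths): reload saved batches.
# Note that __iter__ above cycles over `filenames` forever, so bound the
# iteration yourself (e.g. with itertools.islice).
#     >>> pipe = LoadDictDatasetIterDataPipe(
#     ...     filenames=["batch_0.nc", "batch_1.nc"],
#     ...     keys=["gsp", "nwp", "sat", "pv"],
#     ... )
#     >>> batch_dict = next(iter(pipe))  # dict of key -> xr.Dataset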
@functional_datapipe("convert_to_numpy_batch")
120+
class ConvertToNumpyBatchIterDataPipe(IterDataPipe):
121+
"""Converts Xarray Dataset to Numpy Batch"""
122+
123+
def __init__(
124+
self,
125+
dataset_dict_dp: IterDataPipe,
126+
configuration: Configuration,
127+
block_sat: bool = False,
128+
block_nwp: bool = False,
129+
check_satellite_no_zeros: bool = False,
130+
):
131+
"""Init"""
132+
super().__init__()
133+
self.dataset_dict_dp = dataset_dict_dp
134+
self.configuration = configuration
135+
self.block_sat = block_sat
136+
self.block_nwp = block_nwp
137+
self.check_satellite_no_zeros = check_satellite_no_zeros
138+
139+
def __iter__(self):
140+
"""Iter"""
141+
for datapipes_dict in self.dataset_dict_dp:
142+
# Spatially slice, normalize, and convert data to numpy arrays
143+
numpy_modalities = []
144+
# Unpack for convenience
145+
conf_sat = self.configuration.input_data.satellite
146+
conf_nwp = self.configuration.input_data.nwp
147+
if "nwp" in datapipes_dict:
148+
numpy_modalities.append(datapipes_dict["nwp"].convert_nwp_to_numpy_batch())
149+
if "sat" in datapipes_dict:
150+
numpy_modalities.append(datapipes_dict["sat"].convert_satellite_to_numpy_batch())
151+
if "pv" in datapipes_dict:
152+
numpy_modalities.append(datapipes_dict["pv"].convert_pv_to_numpy_batch())
153+
numpy_modalities.append(datapipes_dict["gsp"].convert_gsp_to_numpy_batch())
154+
155+
logger.debug("Combine all the data sources")
156+
combined_datapipe = MergeNumpyModalities(numpy_modalities).add_sun_position(
157+
modality_name="gsp"
158+
)
159+
160+
if self.block_sat and conf_sat != "":
161+
sat_block_func = AddZeroedSatelliteData(self.configuration)
162+
combined_datapipe = combined_datapipe.map(sat_block_func)
163+
164+
if self.block_nwp and conf_nwp != "":
165+
nwp_block_func = AddZeroedNWPData(self.configuration)
166+
combined_datapipe = combined_datapipe.map(nwp_block_func)
167+
168+
logger.info("Filtering out samples with no data")
169+
if self.check_satellite_no_zeros:
170+
# in production we don't want any nans in the satellite data
171+
combined_datapipe = combined_datapipe.map(check_nans_in_satellite_data)
172+
173+
combined_datapipe = combined_datapipe.map(fill_nans_in_arrays)
174+
175+
yield next(iter(combined_datapipe))
176+
177+
178+
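
# Illustrative usage sketch via the functional form registered above (assumes
# `dataset_dict_pipe` yields dicts of single-element datapipes, as produced by
# split_dataset_dict_dp below, and `configuration` is a loaded Configuration):
#     >>> numpy_batch_pipe = dataset_dict_pipe.convert_to_numpy_batch(
#     ...     configuration=configuration,
#     ... )
#     >>> batch = next(iter(numpy_batch_pipe))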


def minutes(num_mins: int):
    """Timedelta of a number of minutes.

    Args:
        num_mins: Number of minutes.
    """
    return timedelta(minutes=num_mins)


def construct_sliced_data_pipeline(
    config_filename: str,
    location_pipe: IterDataPipe,
    t0_datapipe: IterDataPipe,
    block_sat: bool = False,
    block_nwp: bool = False,
    production: bool = False,
) -> dict:
    """Constructs data pipeline for the input data config file.

    This yields samples from the location and time datapipes.

    Args:
        config_filename: Path to config file.
        location_pipe: Datapipe yielding locations.
        t0_datapipe: Datapipe yielding times.
        block_sat: Whether to load zeroes for satellite data.
        block_nwp: Whether to load zeroes for NWP data.
        production: Whether constructing pipeline for production inference.
    """

    assert not (production and (block_sat or block_nwp))

    datapipes_dict = _get_datapipes_dict(
        config_filename,
        block_sat,
        block_nwp,
        production=production,
    )

    configuration = datapipes_dict.pop("config")

    # Unpack for convenience
    conf_sat = configuration.input_data.satellite
    conf_nwp = configuration.input_data.nwp

    # Slice all of the datasets by time - this is an in-place operation
    slice_datapipes_by_time(datapipes_dict, t0_datapipe, configuration, production)

    if "nwp" in datapipes_dict:
        nwp_datapipe = datapipes_dict["nwp"]

        location_pipe, location_pipe_copy = location_pipe.fork(2, buffer_size=5)
        nwp_datapipe = nwp_datapipe.select_spatial_slice_pixels(
            location_pipe_copy,
            roi_height_pixels=conf_nwp.nwp_image_size_pixels_height,
            roi_width_pixels=conf_nwp.nwp_image_size_pixels_width,
        )
        nwp_datapipe = nwp_datapipe.normalize(mean=NEW_NWP_MEAN, std=NEW_NWP_STD)

    if "sat" in datapipes_dict:
        sat_datapipe = datapipes_dict["sat"]

        location_pipe, location_pipe_copy = location_pipe.fork(2, buffer_size=5)
        sat_datapipe = sat_datapipe.select_spatial_slice_pixels(
            location_pipe_copy,
            roi_height_pixels=conf_sat.satellite_image_size_pixels_height,
            roi_width_pixels=conf_sat.satellite_image_size_pixels_width,
        )
        sat_datapipe = sat_datapipe.normalize(mean=RSS_MEAN, std=RSS_STD)

    if "pv" in datapipes_dict:
        # Recombine PV arrays - see function doc for further explanation
        pv_datapipe = (
            datapipes_dict["pv"].zip_ocf(datapipes_dict["pv_future"]).map(concat_xr_time_utc)
        )
        pv_datapipe = pv_datapipe.normalize(normalize_fn=normalize_pv)
        pv_datapipe = pv_datapipe.map(fill_nans_in_pv)

    # GSP is always assumed to be in the data
    location_pipe, location_pipe_copy = location_pipe.fork(2, buffer_size=5)
    gsp_future_datapipe = datapipes_dict["gsp_future"]
    gsp_future_datapipe = gsp_future_datapipe.select_spatial_slice_meters(
        location_datapipe=location_pipe_copy,
        roi_height_meters=1,
        roi_width_meters=1,
        dim_name="gsp_id",
    )

    gsp_datapipe = datapipes_dict["gsp"]
    gsp_datapipe = gsp_datapipe.select_spatial_slice_meters(
        location_datapipe=location_pipe,
        roi_height_meters=1,
        roi_width_meters=1,
        dim_name="gsp_id",
    )

    # Recombine GSP arrays - see function doc for further explanation
    gsp_datapipe = gsp_datapipe.zip_ocf(gsp_future_datapipe).map(concat_xr_time_utc)
    gsp_datapipe = gsp_datapipe.normalize(normalize_fn=normalize_gsp)

    finished_dataset_dict = {"gsp": gsp_datapipe, "config": configuration}
    if "nwp" in datapipes_dict:
        finished_dataset_dict["nwp"] = nwp_datapipe
    if "sat" in datapipes_dict:
        finished_dataset_dict["sat"] = sat_datapipe
    if "pv" in datapipes_dict:
        finished_dataset_dict["pv"] = pv_datapipe

    return finished_dataset_dict
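
# Illustrative usage sketch (hypothetical config path; location_pipe and
# t0_datapipe as produced by construct_loctime_pipelines):
#     >>> pipes = construct_sliced_data_pipeline(
#     ...     "wind_config.yaml", location_pipe, t0_datapipe
#     ... )
#     >>> sorted(pipes)  # subset of ["config", "gsp", "nwp", "pv", "sat"]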


def windnet_datapipe(
    config_filename: str,
    start_time: Optional[datetime] = None,
    end_time: Optional[datetime] = None,
    block_sat: bool = False,
    block_nwp: bool = False,
) -> IterDataPipe:
    """
    Construct windnet pipeline for the input data config file.

    Args:
        config_filename: Path to config file.
        start_time: Minimum time at which a sample can be selected.
        end_time: Maximum time at which a sample can be selected.
        block_sat: Whether to load zeroes for satellite data.
        block_nwp: Whether to load zeroes for NWP data.
    """
    logger.info("Constructing windnet pipeline")

    # Open datasets from the config and filter to useable location-time pairs
    location_pipe, t0_datapipe = construct_loctime_pipelines(
        config_filename,
        start_time,
        end_time,
        block_sat,
        block_nwp,
    )

    # Shard after we have the loc-times. These are already shuffled so no need to shuffle again
    location_pipe = location_pipe.sharding_filter()
    t0_datapipe = t0_datapipe.sharding_filter()

    # In this function we re-open the datasets to make a clean separation before/after sharding
    datapipe_dict = construct_sliced_data_pipeline(
        config_filename,
        location_pipe,
        t0_datapipe,
        block_sat,
        block_nwp,
    )

    # Merge all the datapipes into one, combining each sample into a single
    # xr.Dataset that can be saved out to NetCDF
    return DictDatasetIterDataPipe(
        datapipe_dict["gsp"],
        datapipe_dict["nwp"],
        datapipe_dict["sat"],
        datapipe_dict["pv"],
        keys=["gsp", "nwp", "sat", "pv"],
    ).map(combine_to_single_dataset)
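
# Illustrative usage sketch (hypothetical paths): each element yielded by the
# pipe is one combined xr.Dataset, which can be written straight to NetCDF.
#     >>> dp = windnet_datapipe("wind_config.yaml")
#     >>> for i, combined_ds in zip(range(4), dp):
#     ...     combined_ds.to_netcdf(f"batch_{i}.nc")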


def split_dataset_dict_dp(element):
    """
    Wrap each xr.Dataset in the dictionary into its own single-element datapipe

    Args:
        element: Dictionary of xr.Datasets
    """
    return {k: IterableWrapper([v]) for k, v in element.items() if k != "config"}


def windnet_netcdf_datapipe(
    config_filename: str,
    keys: List[str],
    filenames: List[str],
    block_sat: bool = False,
    block_nwp: bool = False,
) -> IterDataPipe:
    """
    Load the datasets saved out by windnet, and transform them to numpy batches

    Args:
        config_filename: Path to config file.
        keys: List of keys to extract from the single NetCDF files
        filenames: List of NetCDF files to load
        block_sat: Whether to load zeroes for satellite data.
        block_nwp: Whether to load zeroes for NWP data.

    Returns:
        Datapipe that transforms the NetCDF files to numpy batches
    """
    logger.info("Constructing windnet file pipeline")
    config_datapipe = OpenConfiguration(config_filename)
    configuration: Configuration = next(iter(config_datapipe))
    # Load files
    datapipe_dict_dp: IterDataPipe = LoadDictDatasetIterDataPipe(
        filenames=filenames,
        keys=keys,
    ).map(split_dataset_dict_dp)
    datapipe = datapipe_dict_dp.convert_to_numpy_batch(
        block_nwp=block_nwp, block_sat=block_sat, configuration=configuration
    )

    return datapipe
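
Taken together, the two entry points give a save/reload round trip for batches: windnet_datapipe yields combined xr.Dataset samples that can be written to NetCDF, and windnet_netcdf_datapipe reads them back and converts them to numpy batches. A minimal end-to-end sketch, assuming a config at "wind_config.yaml" and local batch files (both hypothetical):

    from ocf_datapipes.training.windnet import windnet_datapipe, windnet_netcdf_datapipe

    # 1. Build the training pipe and save a few combined samples to disk
    dp = windnet_datapipe("wind_config.yaml")
    filenames = []
    for i, combined_ds in zip(range(4), dp):
        filename = f"batch_{i}.nc"
        combined_ds.to_netcdf(filename)
        filenames.append(filename)

    # 2. Reload the saved files and convert them into numpy batches
    batch_pipe = windnet_netcdf_datapipe(
        "wind_config.yaml",
        keys=["gsp", "nwp", "sat", "pv"],
        filenames=filenames,
    )
    numpy_batch = next(iter(batch_pipe))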
