From f58917cbb910e2274b88a71c6203073eaabe2318 Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Fri, 10 Nov 2023 05:29:15 +0000 Subject: [PATCH 01/29] Add ability to join and break apart Datasets --- ocf_datapipes/training/windnet.py | 865 ++++++++++++++++++++++++++++++ ocf_datapipes/utils/utils.py | 67 +++ 2 files changed, 932 insertions(+) create mode 100644 ocf_datapipes/training/windnet.py diff --git a/ocf_datapipes/training/windnet.py b/ocf_datapipes/training/windnet.py new file mode 100644 index 000000000..dfa28b7af --- /dev/null +++ b/ocf_datapipes/training/windnet.py @@ -0,0 +1,865 @@ +"""Create the training/validation datapipe for training the PVNet Model""" +import logging +from datetime import datetime, timedelta +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import xarray as xr +from torchdata.datapipes import functional_datapipe +from torchdata.datapipes.iter import IterDataPipe + +from ocf_datapipes.batch import MergeNumpyModalities +from ocf_datapipes.config.model import Configuration +from ocf_datapipes.load import OpenGSPFromDatabase, OpenPVFromPVSitesDB +from ocf_datapipes.training.common import ( + create_t0_and_loc_datapipes, + open_and_return_datapipes, +) +from ocf_datapipes.utils.consts import ( + NEW_NWP_MEAN, + NEW_NWP_STD, + RSS_MEAN, + RSS_STD, + BatchKey, + NumpyBatch, +) +from ocf_datapipes.utils.utils import combine_to_single_dataset, uncombine_from_netcdf + +xr.set_options(keep_attrs=True) +logger = logging.getLogger("pvnet_datapipe") + + +def normalize_gsp(x): + """Normalize the GSP data + + Args: + x: Input DataArray + + Returns: + Normalized DataArray + """ + return x / x.effective_capacity_mwp + + +def normalize_pv(x): + """Normalize the PV data + + Args: + x: Input DataArray + + Returns: + Normalized DataArray + """ + return (x / x.nominal_capacity_wp).clip(None, 5) + + +def production_sat_scale(x): + """Scale the production satellite data + + Args: + x: Input DataArray + + Returns: + Scaled DataArray + """ + return x / 1024 + + +def concat_xr_time_utc(gsp_dataarrays: List[xr.DataArray]): + """This function is used to combine the split history and future gsp/pv dataarrays. + + These are split inside the `slice_datapipes_by_time()` function below. + + Splitting them inside that function allows us to apply dropout to the + history GSP/PV whilst leaving the future GSP/PV without NaNs. + + We recombine the history and future with this function to allow us to use the + `MergeNumpyModalities()` datapipe without redefining the BatchKeys. + + The `pvnet` model was also written to use a GSP/PV array which has historical and future + and to split it out. These maintains that assumption. + """ + return xr.concat(gsp_dataarrays, dim="time_utc") + + +def gsp_drop_national(x: Union[xr.DataArray, xr.Dataset]): + """Drop entries for national PV output + + Args: + x: Data source of gsp data + + Returns: + Filtered data source + """ + return x.where(x.gsp_id != 0, drop=True) + + +@functional_datapipe("pvnet_select_pv_by_ml_id") +class PVNetSelectPVbyMLIDIterDataPipe(IterDataPipe): + """Select specific set of PV systems by ML ID.""" + + def __init__(self, source_datapipe: IterDataPipe, ml_ids: np.array): + """Select specific set of PV systems by ML ID. 
+ + Args: + source_datapipe: Datapipe emitting PV xarray data + ml_ids: List-like of ML IDs to select + + Returns: + Filtered data source + """ + self.source_datapipe = source_datapipe + self.ml_ids = ml_ids + + def __iter__(self): + for x in self.source_datapipe: + # Check for missing IDs + ml_ids_not_in_data = ~np.isin(self.ml_ids, x.ml_id) + if ml_ids_not_in_data.any(): + missing_ml_ids = np.array(self.ml_ids)[ml_ids_not_in_data] + logger.warning( + f"The following ML IDs were mising in the PV site-level input data: " + f"{missing_ml_ids}. The values for these IDs will be set to NaN." + ) + + x_filtered = ( + # Many ML-IDs are null, so filter first + x.where(~x.ml_id.isnull(), drop=True) + # Swap dimensions so we can select by ml_id coordinate + .swap_dims({"pv_system_id": "ml_id"}) + # Select IDs - missing IDs are given NaN values + .reindex(ml_id=self.ml_ids) + # Swap back dimensions + .swap_dims({"ml_id": "pv_system_id"}) + ) + yield x_filtered + + +def fill_nans_in_pv(x: Union[xr.DataArray, xr.Dataset]): + """Fill NaNs in PV data with the value -1 + + Args: + x: Input DataArray + + Returns: + Normalized DataArray + """ + return x.fillna(-1) + + +def scale_wind_speed_to_power(x: Union[xr.DataArray, xr.Dataset]): + """ + Scale wind speed to power to estimate the generation of wind power from ground sensors + + Roughly, double speed in m/s, and convert with the power scale + + Args: + x: + + Returns: + + """ + # Convert knots to m/s + x = x * 0.514444 + # Roughly double speed to get power + x = x * 2 + return x + + +def fill_nans_in_arrays(batch: NumpyBatch) -> NumpyBatch: + """Fills all NaN values in each np.ndarray in the batch dictionary with zeros. + + Operation is performed in-place on the batch. + """ + logger.info("Filling Nans with zeros") + for k, v in batch.items(): + if isinstance(v, np.ndarray): + np.nan_to_num(v, copy=False, nan=0.0) + return batch + + +class AddZeroedSatelliteData: + """A callable class used to add zeroed-out satellite data to batches of data. + + This is useful + to speed up batch loading if pre-training the output part of the network without satellite + inputs. + """ + + def __init__(self, configuration: Configuration, is_hrv: bool = False): + """A callable class used to add zeroed-out satellite data to batches of data. + + Args: + configuration: Configuration object + is_hrv: If False, non-HRV data is added by called function, else HRV. + """ + + self.configuration = configuration + self.is_hrv = is_hrv + + def __call__(self, batch: NumpyBatch) -> NumpyBatch: + """Add zeroed-out satellite data to batch with shape accoriding to supplied configuration. + + Batch is modified in-place and returned. + + Args: + batch: Numpy batch of input data. + """ + + variable = "hrvsatellite" if self.is_hrv else "satellite" + + satellite_config = getattr(self.configuration.input_data, variable) + + n_channels = len(getattr(satellite_config, f"{variable}_channels")) + height = getattr(satellite_config, f"{variable}_image_size_pixels_height") + width = getattr(satellite_config, f"{variable}_image_size_pixels_width") + + sequence_len = satellite_config.history_minutes // 5 + 1 - 3 + + batch[getattr(BatchKey, f"{variable}_actual")] = np.zeros( + (sequence_len, n_channels, height, width) + ) + + return batch + + +class AddZeroedNWPData: + """A callable class used to add zeroed-out NWP data to batches of data. + + This is useful to speed up batch loading if pre-training the output part of the network without + NWP inputs. 
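+
+    Instances are applied with `.map()` over the combined numpy batches when `block_nwp=True`.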
+ """ + + def __init__(self, configuration: Configuration): + """A callable class used to add zeroed-out NWP data to batches of data. + + Args: + configuration: Configuration object + """ + self.configuration = configuration + + def __call__(self, batch: NumpyBatch) -> NumpyBatch: + """Add zeroed-out NWP data to batch with shape accoriding to supplied configuration. + + Batch is modified in-place and returned. + + Args: + batch: Numpy batch of input data. + """ + + config = self.configuration.input_data.nwp + + n_channels = len(config.nwp_channels) + height = config.nwp_image_size_pixels_height + width = config.nwp_image_size_pixels_width + + sequence_len = config.history_minutes // 60 + config.forecast_minutes // 60 + 1 + + batch[BatchKey.nwp] = np.zeros((sequence_len, n_channels, height, width)) + + return batch + + +class DatapipeKeyForker: + """ "Internal helper function to track forking of a datapipe.""" + + def __init__(self, keys: List, datapipe: IterDataPipe): + """Internal helper function to track forking of a datapipe. + + As forks are returned, this object tracks the keys left and returns the final copy of the + datapipe when the last key is requested. This makes multiple forking easier and ensures + closure. + + Args: + keys: List of keys for which datapipe duplication is required. + datapipe: Datapipe which will be forked for each ket + """ + self.keys_left = keys + self.datapipe = datapipe + + def __call__(self, key): + """ "Returns a fork of `self.datapipe` and tracks a the keys left to ensure closure. + + Args: + key: key to remove from `self.keys_left`. If `key` is None then an extra copy is made + without affecting `self.keys_left`. + """ + if len(self.keys_left) == 0: + raise ValueError(f"No keys left when requested key : {key}") + if key is not None: + self.keys_left.remove(key) + if len(self.keys_left) > 0: + self.datapipe, return_datapipe = self.datapipe.fork(2, buffer_size=5) + else: + return_datapipe = self.datapipe + return return_datapipe + + def close(self): + """Asserts that the keys have all been used.""" + assert len(self.keys_left) == 0 + + +@functional_datapipe("dict_datasets") +class DictDatasetIterDataPipe(IterDataPipe): + """ """ + + datapipes: Tuple[IterDataPipe] + length: Optional[int] + + def __init__(self, *datapipes: IterDataPipe, keys: List[str]): + """Init""" + if not all(isinstance(dp, IterDataPipe) for dp in datapipes): + raise TypeError( + "All inputs are required to be `IterDataPipe` " "for `ZipIterDataPipe`." 
+ ) + super().__init__() + self.keys = keys + self.datapipes = datapipes # type: ignore[assignment] + self.length = None + assert len(self.keys) == len(self.datapipes), "Number of keys must match number of pipes" + + def __iter__(self): + """Iter""" + iterators = [iter(datapipe) for datapipe in self.datapipes] + for data in zip(*iterators): + # Yield a dictionary of the data, using the keys in self.keys + yield {k: v for k, v in zip(self.keys, data)} + + +def _get_datapipes_dict( + config_filename: str, + block_sat: bool, + block_nwp: bool, + production: bool = False, +): + # Load datasets + datapipes_dict = open_and_return_datapipes( + configuration_filename=config_filename, + use_gsp=(not production), + use_pv=(not production), + use_sat=not block_sat, # Only loaded if we aren't replacing them with zeros + use_hrv=False, + use_nwp=not block_nwp, # Only loaded if we aren't replacing them with zeros + use_topo=False, + production=production, + ) + + config: Configuration = datapipes_dict["config"] + + if production: + datapipes_dict["gsp"] = OpenGSPFromDatabase().add_t0_idx_and_sample_period_duration( + sample_period_duration=timedelta(minutes=30), + history_duration=timedelta(minutes=config.input_data.gsp.history_minutes), + ) + if "sat" in datapipes_dict: + datapipes_dict["sat"] = datapipes_dict["sat"].map(production_sat_scale) + if "pv" in datapipes_dict: + datapipes_dict["pv"] = OpenPVFromPVSitesDB(config.input_data.pv.history_minutes) + + if "pv" in datapipes_dict and config.input_data.pv.pv_ml_ids != []: + datapipes_dict["pv"] = datapipes_dict["pv"].pvnet_select_pv_by_ml_id( + config.input_data.pv.pv_ml_ids + ) + + return datapipes_dict + + +def construct_loctime_pipelines( + config_filename: str, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, + block_sat: bool = False, + block_nwp: bool = False, +) -> Tuple[IterDataPipe, IterDataPipe]: + """Construct location and time pipelines for the input data config file. + + Args: + config_filename: Path to config file. + start_time: Minimum time for time datapipe. + end_time: Maximum time for time datapipe. + block_sat: Whether to load zeroes for satellite data. + block_nwp: Whether to load zeroes for NWP data. + """ + + datapipes_dict = _get_datapipes_dict( + config_filename, + block_sat, + block_nwp, + ) + + # Pull out config file + config = datapipes_dict.pop("config") + + # We sample time and space of other data using GSP time and space coordinates, so filter GSP + # data first amd this is carried through + # Map from wind speed to m/s here + datapipes_dict["gsp"] = datapipes_dict["gsp"] + if (start_time is not None) or (end_time is not None): + datapipes_dict["gsp"] = datapipes_dict["gsp"].select_train_test_time(start_time, end_time) + + # Get overlapping time periods + location_pipe, t0_datapipe = create_t0_and_loc_datapipes( + datapipes_dict, + configuration=config, + key_for_t0="gsp", + shuffle=True, + nwp_max_dropout_minutes=180, + ) + + return location_pipe, t0_datapipe + + +def minutes(num_mins: int): + """Timedelta of a number of minutes. + + Args: + num_mins: Minutes timedelta. + """ + return timedelta(minutes=num_mins) + + +def slice_datapipes_by_time( + datapipes_dict: Dict, + t0_datapipe: IterDataPipe, + configuration: Configuration, + production: bool = False, +) -> None: + """ + Modifies a dictionary of datapipes in-place to yield samples for given times t0. + + The NWP data* will be at least 90 minutes stale (i.e. as if it takes 90 minutes for the foreast + to become available). 
+ + The satellite data* is shaped so that the most recent can be 15 minutes before t0. However, 50% + of the time dropout is applied so that the most recent field is between 45 and 20 minutes before + t0. When dropped out like this, the values after this selected dropout time are set to NaN. + + The HRV data* is similar to the satellite data and if both are included they drop out + simulataneously. + + The GSP data is split into "gsp" and "gsp_future" keys. 10% of the time the gsp value for time + t0, which occurs under the "gsp" key, is set to NaN + + The PV data* is also split it "pv" and "pv_future" keys. + + * if included + + n.b. PV and HRV are included in this function, but not yet in the rest of the pvnet pipeline. + This is mostly for demonstratio purposes of how the dropout might be applied. + + Args: + datapipes_dict: Dictionary of used datapipes and t0 ones + t0_datapipe: Datapipe which yields t0 times for sample + configuration: Configuration object. + production: Whether constucting pipeline for production inference. No dropout is used if + True. + + """ + + conf_in = configuration.input_data + + # Use DatapipeKeyForker to avoid forking t0_datapipe too many times, or leaving any forks unused + fork_keys = {k for k in datapipes_dict.keys() if k not in ["topo"]} + get_t0_datapipe = DatapipeKeyForker(fork_keys, t0_datapipe) + + sat_and_hrv_dropout_kwargs = dict( + # Satellite is either 30 minutes or 60 minutes delayed in production. Match during training + dropout_timedeltas=[minutes(-60), minutes(-30)], + dropout_frac=0 if production else 1.0, + ) + + sat_delay = minutes(-configuration.input_data.satellite.live_delay_minutes) + + if "nwp" in datapipes_dict: + datapipes_dict["nwp"] = datapipes_dict["nwp"].convert_to_nwp_target_time_with_dropout( + t0_datapipe=get_t0_datapipe("nwp"), + sample_period_duration=minutes(60), + history_duration=minutes(conf_in.nwp.history_minutes), + forecast_duration=minutes(conf_in.nwp.forecast_minutes), + # The NWP forecast will always be at least 180 minutes stale + dropout_timedeltas=[minutes(-180)], + dropout_frac=0 if production else 1.0, + ) + + if "sat" in datapipes_dict: + # Take time slices of sat data + datapipes_dict["sat"] = datapipes_dict["sat"].select_time_slice( + t0_datapipe=get_t0_datapipe(None), + sample_period_duration=minutes(5), + interval_start=minutes(-conf_in.satellite.history_minutes), + interval_end=sat_delay, + fill_selection=production, + max_steps_gap=2, + ) + + # Generate randomly sampled dropout times + sat_dropout_time_datapipe = get_t0_datapipe("sat").select_dropout_time( + **sat_and_hrv_dropout_kwargs + ) + + if "hrv" in datapipes_dict: + # Make dropout-time copy for hrv if included in data. + # HRV and non-HRV will dropout simultaneously. 
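+            # Fork the dropout-time pipe: one branch is applied to the non-HRV satellite
+            # data just below, the other to the HRV data in the block further down.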
+ sat_dropout_time_datapipe, hrv_dropout_time_datapipe = sat_dropout_time_datapipe.fork( + 2, buffer_size=5 + ) + + # Apply the dropout + datapipes_dict["sat"] = datapipes_dict["sat"].apply_dropout_time( + dropout_time_datapipe=sat_dropout_time_datapipe, + ) + + if "hrv" in datapipes_dict: + if "sat" not in datapipes_dict: + # Generate randomly sampled dropout times + # This is shared with sat if sat included + hrv_dropout_time_datapipe = get_t0_datapipe(None).select_dropout_time( + **sat_and_hrv_dropout_kwargs + ) + + datapipes_dict["hrv"] = datapipes_dict["hrv"].select_time_slice( + t0_datapipe=get_t0_datapipe("hrv"), + sample_period_duration=minutes(5), + interval_start=minutes(-conf_in.hrvsatellite.history_minutes), + interval_end=sat_delay, + fill_selection=production, + max_steps_gap=2, + ) + + # Apply the dropout + datapipes_dict["hrv"] = datapipes_dict["hrv"].apply_dropout_time( + dropout_time_datapipe=hrv_dropout_time_datapipe, + ) + + if "pv" in datapipes_dict: + datapipes_dict["pv"], dp = datapipes_dict["pv"].fork(2, buffer_size=5) + + datapipes_dict["pv_future"] = dp.select_time_slice( + t0_datapipe=get_t0_datapipe(None), + sample_period_duration=minutes(5), + interval_start=minutes(5), + interval_end=minutes(conf_in.pv.forecast_minutes), + fill_selection=production, + ) + + datapipes_dict["pv"] = datapipes_dict["pv"].select_time_slice( + t0_datapipe=get_t0_datapipe(None), + sample_period_duration=minutes(5), + interval_start=minutes(-conf_in.pv.history_minutes), + interval_end=minutes(0), + fill_selection=production, + ) + + # Dropout on the PV, but not the future PV + pv_dropout_time_datapipe = get_t0_datapipe("pv").select_dropout_time( + # All PV data could be delayed by up to 30 minutes + # (this does not stem from production - just setting for now) + dropout_timedeltas=[minutes(m) for m in range(-30, 0, 5)], + dropout_frac=0.1 if production else 1, + ) + + datapipes_dict["pv"] = datapipes_dict["pv"].apply_dropout_time( + dropout_time_datapipe=pv_dropout_time_datapipe, + ) + + # Apply extra PV dropout using different delays per system and droping out entire PV systems + # independently + if not production: + datapipes_dict["pv"].apply_pv_dropout( + system_dropout_fractions=np.linspace(0, 0.2, 100), + system_dropout_timedeltas=[minutes(m) for m in [-15, -10, -5, 0]], + ) + + if "gsp" in datapipes_dict: + datapipes_dict["gsp"], dp = datapipes_dict["gsp"].fork(2, buffer_size=5) + + datapipes_dict["gsp_future"] = dp.select_time_slice( + t0_datapipe=get_t0_datapipe(None), + sample_period_duration=minutes(30), + interval_start=minutes(30), + interval_end=minutes(conf_in.gsp.forecast_minutes), + fill_selection=production, + ) + + datapipes_dict["gsp"] = datapipes_dict["gsp"].select_time_slice( + t0_datapipe=get_t0_datapipe(None), + sample_period_duration=minutes(30), + interval_start=-minutes(conf_in.gsp.history_minutes), + interval_end=minutes(0), + fill_selection=production, + ) + + # Dropout on the GSP, but not the future GSP + gsp_dropout_time_datapipe = get_t0_datapipe("gsp").select_dropout_time( + # GSP data for time t0 may be missing. 
Only have value for t0-30mins + dropout_timedeltas=[minutes(-30)], + dropout_frac=0 if production else 0.1, + ) + + datapipes_dict["gsp"] = datapipes_dict["gsp"].apply_dropout_time( + dropout_time_datapipe=gsp_dropout_time_datapipe, + ) + + get_t0_datapipe.close() + + return + + +def construct_sliced_data_pipeline( + config_filename: str, + location_pipe: IterDataPipe, + t0_datapipe: IterDataPipe, + block_sat: bool = False, + block_nwp: bool = False, + production: bool = False, +) -> dict: + """Constructs data pipeline for the input data config file. + + This yields samples from the location and time datapipes. + + Args: + config_filename: Path to config file. + location_pipe: Datapipe yielding locations. + t0_datapipe: Datapipe yielding times. + block_sat: Whether to load zeroes for satellite data. + block_nwp: Whether to load zeroes for NWP data. + production: Whether constucting pipeline for production inference. + check_satellite_no_zeros: Whether to check that satellite data has no zeros. + """ + + assert not (production and (block_sat or block_nwp)) + + datapipes_dict = _get_datapipes_dict( + config_filename, + block_sat, + block_nwp, + production=production, + ) + + configuration = datapipes_dict.pop("config") + + # Unpack for convenience + conf_sat = configuration.input_data.satellite + conf_nwp = configuration.input_data.nwp + + # Slice all of the datasets by time - this is an in-place operation + slice_datapipes_by_time(datapipes_dict, t0_datapipe, configuration, production) + + if "nwp" in datapipes_dict: + nwp_datapipe = datapipes_dict["nwp"] + + location_pipe, location_pipe_copy = location_pipe.fork(2, buffer_size=5) + nwp_datapipe = nwp_datapipe.select_spatial_slice_pixels( + location_pipe_copy, + roi_height_pixels=conf_nwp.nwp_image_size_pixels_height, + roi_width_pixels=conf_nwp.nwp_image_size_pixels_width, + ) + nwp_datapipe = nwp_datapipe.normalize(mean=NEW_NWP_MEAN, std=NEW_NWP_STD) + + if "sat" in datapipes_dict: + sat_datapipe = datapipes_dict["sat"] + + location_pipe, location_pipe_copy = location_pipe.fork(2, buffer_size=5) + sat_datapipe = sat_datapipe.select_spatial_slice_pixels( + location_pipe_copy, + roi_height_pixels=conf_sat.satellite_image_size_pixels_height, + roi_width_pixels=conf_sat.satellite_image_size_pixels_width, + ) + sat_datapipe = sat_datapipe.normalize(mean=RSS_MEAN, std=RSS_STD) + + if "pv" in datapipes_dict: + # Recombine PV arrays - see function doc for further explanation + pv_datapipe = ( + datapipes_dict["pv"].zip_ocf(datapipes_dict["pv_future"]).map(concat_xr_time_utc) + ) + pv_datapipe = pv_datapipe.normalize(normalize_fn=normalize_pv) + pv_datapipe = pv_datapipe.map(fill_nans_in_pv) + + # GSP always assumed to be in data + location_pipe, location_pipe_copy = location_pipe.fork(2, buffer_size=5) + gsp_future_datapipe = datapipes_dict["gsp_future"] + gsp_future_datapipe = gsp_future_datapipe.select_spatial_slice_meters( + location_datapipe=location_pipe_copy, + roi_height_meters=1, + roi_width_meters=1, + dim_name="gsp_id", + ) + + gsp_datapipe = datapipes_dict["gsp"] + gsp_datapipe = gsp_datapipe.select_spatial_slice_meters( + location_datapipe=location_pipe, + roi_height_meters=1, + roi_width_meters=1, + dim_name="gsp_id", + ) + + # Recombine GSP arrays - see function doc for further explanation + gsp_datapipe = gsp_datapipe.zip_ocf(gsp_future_datapipe).map(concat_xr_time_utc) + gsp_datapipe = gsp_datapipe.normalize(normalize_fn=normalize_gsp) + + finished_dataset_dict = {"gsp": gsp_datapipe, "config": configuration} + if "nwp" in 
datapipes_dict: + finished_dataset_dict["nwp"] = nwp_datapipe + if "sat" in datapipes_dict: + finished_dataset_dict["sat"] = sat_datapipe + if "pv" in datapipes_dict: + finished_dataset_dict["pv"] = pv_datapipe + + return finished_dataset_dict + + +def convert_to_numpy_batch( + datapipes_dict: dict, + block_sat: bool = False, + block_nwp: bool = False, + check_satellite_no_zeros: bool = False, +): + configuration = datapipes_dict["config"] + # Spatially slice, normalize, and convert data to numpy arrays + numpy_modalities = [] + # Unpack for convenience + conf_sat = configuration.input_data.satellite + conf_nwp = configuration.input_data.nwp + if "nwp" in datapipes_dict: + numpy_modalities.append(datapipes_dict["nwp"].convert_nwp_to_numpy_batch()) + if "sat" in datapipes_dict: + numpy_modalities.append(datapipes_dict["sat"].convert_satellite_to_numpy_batch()) + if "pv" in datapipes_dict: + numpy_modalities.append(datapipes_dict["pv"].convert_pv_to_numpy_batch()) + numpy_modalities.append(datapipes_dict["gsp"].convert_gsp_to_numpy_batch()) + + logger.debug("Combine all the data sources") + combined_datapipe = MergeNumpyModalities(numpy_modalities).add_sun_position(modality_name="gsp") + + if block_sat and conf_sat != "": + sat_block_func = AddZeroedSatelliteData(configuration) + combined_datapipe = combined_datapipe.map(sat_block_func) + + if block_nwp and conf_nwp != "": + nwp_block_func = AddZeroedNWPData(configuration) + combined_datapipe = combined_datapipe.map(nwp_block_func) + + logger.info("Filtering out samples with no data") + if check_satellite_no_zeros: + # in production we don't want any nans in the satellite data + combined_datapipe = combined_datapipe.map(check_nans_in_satellite_data) + + combined_datapipe = combined_datapipe.map(fill_nans_in_arrays) + + return combined_datapipe + + +def write_to_netcdf(datapipes_dict): + """ + Write the batch to a netcdf file. + """ + dataset = combine_to_single_dataset(datapipes_dict) + print(dataset) + + +def windnet_datapipe( + config_filename: str, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, + block_sat: bool = False, + block_nwp: bool = False, +) -> IterDataPipe: + """ + Construct windnet pipeline for the input data config file. + + Args: + config_filename: Path to config file. + start_time: Minimum time at which a sample can be selected. + end_time: Maximum time at which a sample can be selected. + block_sat: Whether to load zeroes for satellite data. + block_nwp: Whether to load zeroes for NWP data. + """ + logger.info("Constructing windnet pipeline") + + # Open datasets from the config and filter to useable location-time pairs + location_pipe, t0_datapipe = construct_loctime_pipelines( + config_filename, + start_time, + end_time, + block_sat, + block_nwp, + ) + + # Shard after we have the loc-times. 
These are already shuffled so no need to shuffle again + location_pipe = location_pipe.sharding_filter() + t0_datapipe = t0_datapipe.sharding_filter() + + # In this function we re-open the datasets to make a clean separation before/after sharding + # This function + datapipe_dict = construct_sliced_data_pipeline( + config_filename, + location_pipe, + t0_datapipe, + block_sat, + block_nwp, + ) + + # Save out datapipe to NetCDF + + # Convert to numpy batch + # datapipe = convert_to_numpy_batch( + # datapipe_dict, + # block_sat, + # block_nwp, + # ) + # Merge all the datapipes into one + return DictDatasetIterDataPipe( + datapipe_dict["gsp"], + datapipe_dict["nwp"], + datapipe_dict["sat"], + datapipe_dict["pv"], + keys=["gsp", "nwp", "sat", "pv"], + ) + + +def check_nans_in_satellite_data(batch: NumpyBatch) -> NumpyBatch: + """ + Check if there are any Nans values in the satellite data. + """ + if np.any(np.isnan(batch[BatchKey.satellite_actual])): + logger.error("Found nans values in satellite data") + + logger.error(batch[BatchKey.satellite_actual].shape) + + # loop over time and channels + for dim in [0, 1]: + for t in range(batch[BatchKey.satellite_actual].shape[dim]): + if dim == 0: + sate_data_one_step = batch[BatchKey.satellite_actual][t] + else: + sate_data_one_step = batch[BatchKey.satellite_actual][:, t] + nans = np.isnan(sate_data_one_step) + + if np.any(nans): + percent_nans = np.sum(nans) / np.prod(sate_data_one_step.shape) * 100 + + logger.error( + f"Found nans values in satellite data at index {t} ({dim=}). " + f"{percent_nans}% of values are nans" + ) + else: + logger.error(f"Found no nans values in satellite data at index {t} {dim=}") + + raise ValueError("Found nans values in satellite data") + + return batch + + +if __name__ == "__main__": + start_time = datetime(1900, 1, 1) + end_time = datetime(2050, 1, 1) + configuration_filename = "/home/jacob/Development/ocf_datapipes/tests/config/test.yaml" + dp = windnet_datapipe( + configuration_filename, + start_time=start_time, + end_time=end_time, + ) + datasets = next(iter(dp)) + dataset = combine_to_single_dataset(datasets) + multiple_datasets = uncombine_from_netcdf(dataset) + print(multiple_datasets) diff --git a/ocf_datapipes/utils/utils.py b/ocf_datapipes/utils/utils.py index edbbf907c..d02f7ebf5 100644 --- a/ocf_datapipes/utils/utils.py +++ b/ocf_datapipes/utils/utils.py @@ -329,3 +329,70 @@ def trigonometric_datetime_transformation(datetimes: npt.ArrayLike) -> np.ndarra return np.concatenate( [sine_month, cosine_month, sine_day, cosine_day, sine_hour, cosine_hour], axis=1 ) + + +def combine_to_single_dataset(dataset_dict: dict[str, xr.Dataset]) -> xr.Dataset: + """ + Combine multiple datasets into a single dataset + + Args: + *datasets: Datasets to combine + + Returns: + Combined dataset + """ + # Convert all data_arrays to datasets + for key, datasets in dataset_dict.items(): + new_datasets = [] + for dataset in datasets: + if isinstance(dataset, xr.DataArray): + new_datasets.append(dataset.to_dataset(name=key)) + else: + new_datasets.append(dataset) + dataset_dict[key] = new_datasets + # Prepend all coordinates and dimensions names with the key in the dataset_dict + final_datasets_to_combined = [] + for key, datasets in dataset_dict.items(): + batched_datasets = [] + for dataset in datasets: + dataset = dataset.rename( + {dim: f"{key}__{dim}" for dim in dataset.dims if dim not in dataset.coords} + ) + dataset = dataset.rename({coord: f"{key}__{coord}" for coord in dataset.coords}) + batched_datasets.append(dataset) + # 
Merge all datasets with the same key + dataset = xr.concat(batched_datasets, dim=f"{key}__time_utc") + final_datasets_to_combined.append(dataset) + # Combine all datasets, and append the list of datasets to the dataset_dict + combined_dataset = xr.merge(final_datasets_to_combined) + return combined_dataset + + +def uncombine_from_netcdf(combined_dataset: xr.Dataset) -> dict[str, xr.DataArray]: + """ + Uncombine a combined dataset + + Args: + combined_dataset: The combined NetCDF dataset + + Returns: + The uncombined datasets as a dict of xr.Datasets + """ + # Split into datasets by splitting by the prefix added in combine_to_netcdf + datasets = {} + # Go through each data variable and split it into a dataset + for key, dataset in combined_dataset.items(): + # If 'key_' doesn't exist in a dim or coordinate, remove it + dataset_dims = list(dataset.coords) + for dim in dataset_dims: + if f"{key}__" not in dim: + dataset = dataset.drop(dim) + # print(dataset) + dataset = dataset.rename( + {dim: dim.split(f"{key}__")[1] for dim in dataset.dims if dim not in dataset.coords} + ) + # print(dataset) + dataset = dataset.rename({coord: coord.split(f"{key}__")[1] for coord in dataset.coords}) + # Split the dataset by the prefix + datasets[key] = dataset + return datasets From a3435a6ba48dd95c32b5b3c132c7dd34a5ca5954 Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Tue, 14 Nov 2023 12:17:40 +0000 Subject: [PATCH 02/29] Add test of combine/uncombine --- ocf_datapipes/training/windnet.py | 4 ++-- ocf_datapipes/utils/utils.py | 2 +- tests/utils/test_utils.py | 18 ++++++++++++++++++ 3 files changed, 21 insertions(+), 3 deletions(-) diff --git a/ocf_datapipes/training/windnet.py b/ocf_datapipes/training/windnet.py index dfa28b7af..8ff47aa0e 100644 --- a/ocf_datapipes/training/windnet.py +++ b/ocf_datapipes/training/windnet.py @@ -23,7 +23,7 @@ BatchKey, NumpyBatch, ) -from ocf_datapipes.utils.utils import combine_to_single_dataset, uncombine_from_netcdf +from ocf_datapipes.utils.utils import combine_to_single_dataset, uncombine_from_single_dataset xr.set_options(keep_attrs=True) logger = logging.getLogger("pvnet_datapipe") @@ -861,5 +861,5 @@ def check_nans_in_satellite_data(batch: NumpyBatch) -> NumpyBatch: ) datasets = next(iter(dp)) dataset = combine_to_single_dataset(datasets) - multiple_datasets = uncombine_from_netcdf(dataset) + multiple_datasets = uncombine_from_single_dataset(dataset) print(multiple_datasets) diff --git a/ocf_datapipes/utils/utils.py b/ocf_datapipes/utils/utils.py index d02f7ebf5..3c3c43977 100644 --- a/ocf_datapipes/utils/utils.py +++ b/ocf_datapipes/utils/utils.py @@ -368,7 +368,7 @@ def combine_to_single_dataset(dataset_dict: dict[str, xr.Dataset]) -> xr.Dataset return combined_dataset -def uncombine_from_netcdf(combined_dataset: xr.Dataset) -> dict[str, xr.DataArray]: +def uncombine_from_single_dataset(combined_dataset: xr.Dataset) -> dict[str, xr.DataArray]: """ Uncombine a combined dataset diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py index d6792559b..4978dc6b7 100644 --- a/tests/utils/test_utils.py +++ b/tests/utils/test_utils.py @@ -1,5 +1,8 @@ import numpy as np from ocf_datapipes.utils.utils import searchsorted +from ocf_datapipes.utils.utils import combine_to_single_dataset, uncombine_from_single_dataset +from ocf_datapipes.training.windnet import windnet_datapipe +from datetime import datetime def test_searchsorted(): @@ -7,3 +10,18 @@ def test_searchsorted(): assert searchsorted(ys, 2.1) == 2 ys_r = np.array([5, 4, 3, 2, 1], dtype=np.float32) 
assert searchsorted(ys_r, 2.1, assume_ascending=False) == 3 + + +def test_combine_uncombine_from_single_dataset(configuration_filename): + start_time = datetime(1900, 1, 1) + end_time = datetime(2050, 1, 1) + dp = windnet_datapipe( + configuration_filename, + start_time=start_time, + end_time=end_time, + ) + datasets = next(iter(dp)) + dataset = combine_to_single_dataset(datasets) + multiple_datasets = uncombine_from_single_dataset(dataset) + for key in multiple_datasets.keys(): + assert datasets[key].equals(multiple_datasets[key]) From e77ed5848cc3ed4f86c512b7f02b1cb98ed4240e Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Tue, 14 Nov 2023 12:29:15 +0000 Subject: [PATCH 03/29] Add another assert --- tests/utils/test_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py index 4978dc6b7..c6adb6be1 100644 --- a/tests/utils/test_utils.py +++ b/tests/utils/test_utils.py @@ -3,6 +3,7 @@ from ocf_datapipes.utils.utils import combine_to_single_dataset, uncombine_from_single_dataset from ocf_datapipes.training.windnet import windnet_datapipe from datetime import datetime +import xarray as xr def test_searchsorted(): @@ -21,7 +22,8 @@ def test_combine_uncombine_from_single_dataset(configuration_filename): end_time=end_time, ) datasets = next(iter(dp)) - dataset = combine_to_single_dataset(datasets) + dataset: xr.Dataset = combine_to_single_dataset(datasets) + assert isinstance(dataset, xr.Dataset) multiple_datasets = uncombine_from_single_dataset(dataset) for key in multiple_datasets.keys(): assert datasets[key].equals(multiple_datasets[key]) From b3f91815cd80111af2f36e7753d7465fd752181c Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Wed, 15 Nov 2023 07:35:25 +0000 Subject: [PATCH 04/29] Fix dual import --- ocf_datapipes/training/windnet.py | 142 +++++++++++++++++++----------- 1 file changed, 91 insertions(+), 51 deletions(-) diff --git a/ocf_datapipes/training/windnet.py b/ocf_datapipes/training/windnet.py index 8ff47aa0e..18b6054f7 100644 --- a/ocf_datapipes/training/windnet.py +++ b/ocf_datapipes/training/windnet.py @@ -6,7 +6,7 @@ import numpy as np import xarray as xr from torchdata.datapipes import functional_datapipe -from torchdata.datapipes.iter import IterDataPipe +from torchdata.datapipes.iter import IterDataPipe, IterableWrapper from ocf_datapipes.batch import MergeNumpyModalities from ocf_datapipes.config.model import Configuration @@ -23,6 +23,9 @@ BatchKey, NumpyBatch, ) +from ocf_datapipes.load import ( + OpenConfiguration, +) from ocf_datapipes.utils.utils import combine_to_single_dataset, uncombine_from_single_dataset xr.set_options(keep_attrs=True) @@ -94,47 +97,6 @@ def gsp_drop_national(x: Union[xr.DataArray, xr.Dataset]): return x.where(x.gsp_id != 0, drop=True) -@functional_datapipe("pvnet_select_pv_by_ml_id") -class PVNetSelectPVbyMLIDIterDataPipe(IterDataPipe): - """Select specific set of PV systems by ML ID.""" - - def __init__(self, source_datapipe: IterDataPipe, ml_ids: np.array): - """Select specific set of PV systems by ML ID. 
- - Args: - source_datapipe: Datapipe emitting PV xarray data - ml_ids: List-like of ML IDs to select - - Returns: - Filtered data source - """ - self.source_datapipe = source_datapipe - self.ml_ids = ml_ids - - def __iter__(self): - for x in self.source_datapipe: - # Check for missing IDs - ml_ids_not_in_data = ~np.isin(self.ml_ids, x.ml_id) - if ml_ids_not_in_data.any(): - missing_ml_ids = np.array(self.ml_ids)[ml_ids_not_in_data] - logger.warning( - f"The following ML IDs were mising in the PV site-level input data: " - f"{missing_ml_ids}. The values for these IDs will be set to NaN." - ) - - x_filtered = ( - # Many ML-IDs are null, so filter first - x.where(~x.ml_id.isnull(), drop=True) - # Swap dimensions so we can select by ml_id coordinate - .swap_dims({"pv_system_id": "ml_id"}) - # Select IDs - missing IDs are given NaN values - .reindex(ml_id=self.ml_ids) - # Swap back dimensions - .swap_dims({"ml_id": "pv_system_id"}) - ) - yield x_filtered - - def fill_nans_in_pv(x: Union[xr.DataArray, xr.Dataset]): """Fill NaNs in PV data with the value -1 @@ -326,6 +288,35 @@ def __iter__(self): yield {k: v for k, v in zip(self.keys, data)} +@functional_datapipe("load_dict_datasets") +class LoadDictDatasetIterDataPipe(IterDataPipe): + """ """ + + filenames: List[str] + keys: List[str] + configuration: Configuration + + def __init__(self, filenames: List[str], keys: List[str], configuration: Configuration): + """Init""" + super().__init__() + self.keys = keys + self.filenames = filenames + self.configuration + + def __iter__(self): + """Iter""" + # Iterate through each filename, loading it, uncombining it, and then yielding it + while True: + for filename in self.filenames: + dataset = xr.open_dataset(filename) + datasets = uncombine_from_single_dataset(dataset) + # Yield a dictionary of the data, using the keys in self.keys + dataset_dict = {} + for k in self.keys: + dataset_dict[k] = datasets[k] + yield dataset_dict + + def _get_datapipes_dict( config_filename: str, block_sat: bool, @@ -356,11 +347,6 @@ def _get_datapipes_dict( if "pv" in datapipes_dict: datapipes_dict["pv"] = OpenPVFromPVSitesDB(config.input_data.pv.history_minutes) - if "pv" in datapipes_dict and config.input_data.pv.pv_ml_ids != []: - datapipes_dict["pv"] = datapipes_dict["pv"].pvnet_select_pv_by_ml_id( - config.input_data.pv.pv_ml_ids - ) - return datapipes_dict @@ -709,12 +695,12 @@ def construct_sliced_data_pipeline( def convert_to_numpy_batch( - datapipes_dict: dict, + datapipes_dict: dict[str, Union[IterDataPipe, Configuration]], block_sat: bool = False, block_nwp: bool = False, check_satellite_no_zeros: bool = False, ): - configuration = datapipes_dict["config"] + configuration: Configuration = datapipes_dict["config"] # Spatially slice, normalize, and convert data to numpy arrays numpy_modalities = [] # Unpack for convenience @@ -817,6 +803,54 @@ def windnet_datapipe( ) +def split_dataset_dict_dp(element): + """ + Split the dictionary of datapipes into individual datapipes + """ + return {k: IterableWrapper([v]) for k, v in element.items() if k != "config"} + + +def windnet_netcdf_datapipe( + config_filename: str, + keys: List[str], + filenames: List[str], + block_sat: bool = False, + block_nwp: bool = False, +) -> IterDataPipe: + """ + Load the saved Datapipes from windnet, and transform to numpy batch + + Args: + config_filename: Path to config file. + keys: List of keys to extract from the single NetCDF files + block_sat: Whether to load zeroes for satellite data. + block_nwp: Whether to load zeroes for NWP data. 
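+        filenames: List of saved NetCDF files to load the samples from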
+ + Returns: + Datapipe that transforms the NetCDF files to numpy batch + """ + logger.info("Constructing windnet file pipeline") + config_datapipe = OpenConfiguration(config_filename) + configuration: Configuration = next(iter(config_datapipe)) + # Load files + datapipe_dict_dp: IterDataPipe = LoadDictDatasetIterDataPipe( + filenames=filenames, + keys=keys, + configuration=configuration, + ) + # Split the dataset_dict_dp into dictionary of individual datapipes + datapipe_dict: dict[str:IterDataPipe] = datapipe_dict_dp.map(split_dataset_dict_dp) + + # Convert to numpy batch + datapipe = convert_to_numpy_batch( + datapipe_dict, + block_sat, + block_nwp, + ) + + return datapipe + + def check_nans_in_satellite_data(batch: NumpyBatch) -> NumpyBatch: """ Check if there are any Nans values in the satellite data. @@ -861,5 +895,11 @@ def check_nans_in_satellite_data(batch: NumpyBatch) -> NumpyBatch: ) datasets = next(iter(dp)) dataset = combine_to_single_dataset(datasets) - multiple_datasets = uncombine_from_single_dataset(dataset) - print(multiple_datasets) + dataset.to_zarr("test.nc", mode="w", compute=True) + dp = windnet_netcdf_datapipe( + config_filename=configuration_filename, + filenames=["test.zarr"], + keys=["gsp", "nwp", "sat", "pv"], + ) + datasets = next(iter(dp)) + print(datasets) From 904f166886fea45f2a0765bb0e8ddde1c3372bf1 Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Wed, 15 Nov 2023 10:11:04 +0000 Subject: [PATCH 05/29] Update utils and test to return Datasets instead of DataArray --- ocf_datapipes/utils/utils.py | 4 +++- tests/utils/test_utils.py | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/ocf_datapipes/utils/utils.py b/ocf_datapipes/utils/utils.py index 3c3c43977..275e6baa9 100644 --- a/ocf_datapipes/utils/utils.py +++ b/ocf_datapipes/utils/utils.py @@ -392,7 +392,9 @@ def uncombine_from_single_dataset(combined_dataset: xr.Dataset) -> dict[str, xr. 
{dim: dim.split(f"{key}__")[1] for dim in dataset.dims if dim not in dataset.coords} ) # print(dataset) - dataset = dataset.rename({coord: coord.split(f"{key}__")[1] for coord in dataset.coords}) + dataset: xr.Dataset = dataset.rename( + {coord: coord.split(f"{key}__")[1] for coord in dataset.coords} + ).to_dataset(name=key) # Split the dataset by the prefix datasets[key] = dataset return datasets diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py index c6adb6be1..fb958ac70 100644 --- a/tests/utils/test_utils.py +++ b/tests/utils/test_utils.py @@ -26,4 +26,5 @@ def test_combine_uncombine_from_single_dataset(configuration_filename): assert isinstance(dataset, xr.Dataset) multiple_datasets = uncombine_from_single_dataset(dataset) for key in multiple_datasets.keys(): - assert datasets[key].equals(multiple_datasets[key]) + for i in range(len(multiple_datasets[key].time_utc)): + assert datasets[key][i].equals(multiple_datasets[key].isel(time_utc=i)) From 6ff4500bd4a55a57e550cb1fb9f5066a87b1bfc2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 15 Nov 2023 07:35:57 +0000 Subject: [PATCH 06/29] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- ocf_datapipes/training/windnet.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/ocf_datapipes/training/windnet.py b/ocf_datapipes/training/windnet.py index 18b6054f7..75ffb72e6 100644 --- a/ocf_datapipes/training/windnet.py +++ b/ocf_datapipes/training/windnet.py @@ -6,11 +6,15 @@ import numpy as np import xarray as xr from torchdata.datapipes import functional_datapipe -from torchdata.datapipes.iter import IterDataPipe, IterableWrapper +from torchdata.datapipes.iter import IterableWrapper, IterDataPipe from ocf_datapipes.batch import MergeNumpyModalities from ocf_datapipes.config.model import Configuration -from ocf_datapipes.load import OpenGSPFromDatabase, OpenPVFromPVSitesDB +from ocf_datapipes.load import ( + OpenConfiguration, + OpenGSPFromDatabase, + OpenPVFromPVSitesDB, +) from ocf_datapipes.training.common import ( create_t0_and_loc_datapipes, open_and_return_datapipes, @@ -23,9 +27,6 @@ BatchKey, NumpyBatch, ) -from ocf_datapipes.load import ( - OpenConfiguration, -) from ocf_datapipes.utils.utils import combine_to_single_dataset, uncombine_from_single_dataset xr.set_options(keep_attrs=True) From f15c935b9f8efc198d15cc023e799acd3e3507da Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Wed, 15 Nov 2023 11:43:00 +0000 Subject: [PATCH 07/29] Fixes for using combing and uncombine --- ocf_datapipes/training/windnet.py | 128 ++++++++++++++++-------------- ocf_datapipes/utils/utils.py | 13 ++- 2 files changed, 79 insertions(+), 62 deletions(-) diff --git a/ocf_datapipes/training/windnet.py b/ocf_datapipes/training/windnet.py index 75ffb72e6..068726038 100644 --- a/ocf_datapipes/training/windnet.py +++ b/ocf_datapipes/training/windnet.py @@ -302,7 +302,7 @@ def __init__(self, filenames: List[str], keys: List[str], configuration: Configu super().__init__() self.keys = keys self.filenames = filenames - self.configuration + self.configuration = configuration def __iter__(self): """Iter""" @@ -318,6 +318,65 @@ def __iter__(self): yield dataset_dict +@functional_datapipe("convert_to_numpy_batch") +class ConvertToNumpyBatchIterDataPipe(IterDataPipe): + """ """ + + def __init__( + self, + dataset_dict_dp: IterDataPipe, + configuration: Configuration, + block_sat: bool = False, + block_nwp: 
bool = False, + check_satellite_no_zeros: bool = False, + ): + """Init""" + super().__init__() + self.dataset_dict_dp = dataset_dict_dp + self.configuration = configuration + self.block_sat = block_sat + self.block_nwp = block_nwp + self.check_satellite_no_zeros = check_satellite_no_zeros + + def __iter__(self): + """Iter""" + for datapipes_dict in self.dataset_dict_dp: + # Spatially slice, normalize, and convert data to numpy arrays + numpy_modalities = [] + # Unpack for convenience + conf_sat = self.configuration.input_data.satellite + conf_nwp = self.configuration.input_data.nwp + if "nwp" in datapipes_dict: + numpy_modalities.append(datapipes_dict["nwp"].convert_nwp_to_numpy_batch()) + if "sat" in datapipes_dict: + numpy_modalities.append(datapipes_dict["sat"].convert_satellite_to_numpy_batch()) + if "pv" in datapipes_dict: + numpy_modalities.append(datapipes_dict["pv"].convert_pv_to_numpy_batch()) + numpy_modalities.append(datapipes_dict["gsp"].convert_gsp_to_numpy_batch()) + + logger.debug("Combine all the data sources") + combined_datapipe = MergeNumpyModalities(numpy_modalities).add_sun_position( + modality_name="gsp" + ) + + if self.block_sat and conf_sat != "": + sat_block_func = AddZeroedSatelliteData(self.configuration) + combined_datapipe = combined_datapipe.map(sat_block_func) + + if self.block_nwp and conf_nwp != "": + nwp_block_func = AddZeroedNWPData(self.configuration) + combined_datapipe = combined_datapipe.map(nwp_block_func) + + logger.info("Filtering out samples with no data") + if self.check_satellite_no_zeros: + # in production we don't want any nans in the satellite data + combined_datapipe = combined_datapipe.map(check_nans_in_satellite_data) + + combined_datapipe = combined_datapipe.map(fill_nans_in_arrays) + + yield combined_datapipe + + def _get_datapipes_dict( config_filename: str, block_sat: bool, @@ -695,55 +754,6 @@ def construct_sliced_data_pipeline( return finished_dataset_dict -def convert_to_numpy_batch( - datapipes_dict: dict[str, Union[IterDataPipe, Configuration]], - block_sat: bool = False, - block_nwp: bool = False, - check_satellite_no_zeros: bool = False, -): - configuration: Configuration = datapipes_dict["config"] - # Spatially slice, normalize, and convert data to numpy arrays - numpy_modalities = [] - # Unpack for convenience - conf_sat = configuration.input_data.satellite - conf_nwp = configuration.input_data.nwp - if "nwp" in datapipes_dict: - numpy_modalities.append(datapipes_dict["nwp"].convert_nwp_to_numpy_batch()) - if "sat" in datapipes_dict: - numpy_modalities.append(datapipes_dict["sat"].convert_satellite_to_numpy_batch()) - if "pv" in datapipes_dict: - numpy_modalities.append(datapipes_dict["pv"].convert_pv_to_numpy_batch()) - numpy_modalities.append(datapipes_dict["gsp"].convert_gsp_to_numpy_batch()) - - logger.debug("Combine all the data sources") - combined_datapipe = MergeNumpyModalities(numpy_modalities).add_sun_position(modality_name="gsp") - - if block_sat and conf_sat != "": - sat_block_func = AddZeroedSatelliteData(configuration) - combined_datapipe = combined_datapipe.map(sat_block_func) - - if block_nwp and conf_nwp != "": - nwp_block_func = AddZeroedNWPData(configuration) - combined_datapipe = combined_datapipe.map(nwp_block_func) - - logger.info("Filtering out samples with no data") - if check_satellite_no_zeros: - # in production we don't want any nans in the satellite data - combined_datapipe = combined_datapipe.map(check_nans_in_satellite_data) - - combined_datapipe = combined_datapipe.map(fill_nans_in_arrays) - - 
return combined_datapipe - - -def write_to_netcdf(datapipes_dict): - """ - Write the batch to a netcdf file. - """ - dataset = combine_to_single_dataset(datapipes_dict) - print(dataset) - - def windnet_datapipe( config_filename: str, start_time: Optional[datetime] = None, @@ -838,15 +848,9 @@ def windnet_netcdf_datapipe( filenames=filenames, keys=keys, configuration=configuration, - ) - # Split the dataset_dict_dp into dictionary of individual datapipes - datapipe_dict: dict[str:IterDataPipe] = datapipe_dict_dp.map(split_dataset_dict_dp) - - # Convert to numpy batch - datapipe = convert_to_numpy_batch( - datapipe_dict, - block_sat, - block_nwp, + ).map(split_dataset_dict_dp) + datapipe = datapipe_dict_dp.convert_to_numpy_batch( + block_nwp=block_nwp, block_sat=block_sat, configuration=configuration ) return datapipe @@ -896,10 +900,12 @@ def check_nans_in_satellite_data(batch: NumpyBatch) -> NumpyBatch: ) datasets = next(iter(dp)) dataset = combine_to_single_dataset(datasets) - dataset.to_zarr("test.nc", mode="w", compute=True) + print(dataset) + # Need to serialize attributes to strings + dataset.to_netcdf("test.nc", mode="w", engine="h5netcdf", compute=True) dp = windnet_netcdf_datapipe( config_filename=configuration_filename, - filenames=["test.zarr"], + filenames=["test.nc"], keys=["gsp", "nwp", "sat", "pv"], ) datasets = next(iter(dp)) diff --git a/ocf_datapipes/utils/utils.py b/ocf_datapipes/utils/utils.py index 275e6baa9..496bd43d7 100644 --- a/ocf_datapipes/utils/utils.py +++ b/ocf_datapipes/utils/utils.py @@ -345,6 +345,10 @@ def combine_to_single_dataset(dataset_dict: dict[str, xr.Dataset]) -> xr.Dataset for key, datasets in dataset_dict.items(): new_datasets = [] for dataset in datasets: + # Convert all coordinates float64 and int64 to float32 and int32 + dataset = dataset.assign_attrs( + {key: str(value) for key, value in dataset.attrs.items()} + ) if isinstance(dataset, xr.DataArray): new_datasets.append(dataset.to_dataset(name=key)) else: @@ -361,7 +365,14 @@ def combine_to_single_dataset(dataset_dict: dict[str, xr.Dataset]) -> xr.Dataset dataset = dataset.rename({coord: f"{key}__{coord}" for coord in dataset.coords}) batched_datasets.append(dataset) # Merge all datasets with the same key - dataset = xr.concat(batched_datasets, dim=f"{key}__time_utc") + # If NWP, then has init_time_utc and step, so do it off key__init_time_utc + dataset = xr.concat( + batched_datasets, + dim=f"{key}__target_time_utc" + if f"{key}__target_time_utc" in dataset.coords + else f"{key}__time_utc", + ) + # Serialize attributes to be JSON-seriaizable final_datasets_to_combined.append(dataset) # Combine all datasets, and append the list of datasets to the dataset_dict combined_dataset = xr.merge(final_datasets_to_combined) From f08118595a89e3e2523be8364f26b3aa75bf0bc8 Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Wed, 15 Nov 2023 12:03:38 +0000 Subject: [PATCH 08/29] Fix combine/uncombine for PVNet --- ocf_datapipes/training/windnet.py | 5 +++-- ocf_datapipes/utils/utils.py | 5 ++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/ocf_datapipes/training/windnet.py b/ocf_datapipes/training/windnet.py index 068726038..d4b2cd343 100644 --- a/ocf_datapipes/training/windnet.py +++ b/ocf_datapipes/training/windnet.py @@ -5,6 +5,8 @@ import numpy as np import xarray as xr + +xr.set_options(keep_attrs=True) from torchdata.datapipes import functional_datapipe from torchdata.datapipes.iter import IterableWrapper, IterDataPipe @@ -374,7 +376,7 @@ def __iter__(self): combined_datapipe = 
combined_datapipe.map(fill_nans_in_arrays) - yield combined_datapipe + yield next(iter(combined_datapipe)) def _get_datapipes_dict( @@ -900,7 +902,6 @@ def check_nans_in_satellite_data(batch: NumpyBatch) -> NumpyBatch: ) datasets = next(iter(dp)) dataset = combine_to_single_dataset(datasets) - print(dataset) # Need to serialize attributes to strings dataset.to_netcdf("test.nc", mode="w", engine="h5netcdf", compute=True) dp = windnet_netcdf_datapipe( diff --git a/ocf_datapipes/utils/utils.py b/ocf_datapipes/utils/utils.py index 496bd43d7..abbd5275a 100644 --- a/ocf_datapipes/utils/utils.py +++ b/ocf_datapipes/utils/utils.py @@ -376,6 +376,7 @@ def combine_to_single_dataset(dataset_dict: dict[str, xr.Dataset]) -> xr.Dataset final_datasets_to_combined.append(dataset) # Combine all datasets, and append the list of datasets to the dataset_dict combined_dataset = xr.merge(final_datasets_to_combined) + # Print all attrbutes of the combined dataset return combined_dataset @@ -398,14 +399,12 @@ def uncombine_from_single_dataset(combined_dataset: xr.Dataset) -> dict[str, xr. for dim in dataset_dims: if f"{key}__" not in dim: dataset = dataset.drop(dim) - # print(dataset) dataset = dataset.rename( {dim: dim.split(f"{key}__")[1] for dim in dataset.dims if dim not in dataset.coords} ) - # print(dataset) dataset: xr.Dataset = dataset.rename( {coord: coord.split(f"{key}__")[1] for coord in dataset.coords} - ).to_dataset(name=key) + ) # Split the dataset by the prefix datasets[key] = dataset return datasets From 033aa95908afeca6c43607c370f5a7f2d98f95d9 Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Wed, 15 Nov 2023 13:26:02 +0000 Subject: [PATCH 09/29] Update tests --- tests/utils/test_utils.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py index fb958ac70..20632d738 100644 --- a/tests/utils/test_utils.py +++ b/tests/utils/test_utils.py @@ -27,4 +27,13 @@ def test_combine_uncombine_from_single_dataset(configuration_filename): multiple_datasets = uncombine_from_single_dataset(dataset) for key in multiple_datasets.keys(): for i in range(len(multiple_datasets[key].time_utc)): - assert datasets[key][i].equals(multiple_datasets[key].isel(time_utc=i)) + # Assert that coordinates are the same + assert ( + datasets[key][i].coords.keys() + == multiple_datasets[key].isel(time_utc=i).coords.keys() + ) + # Assert that data for each of the coords is the same + for coord_key in datasets[key][i].coords.keys(): + assert datasets[key][i][coord_key].equals( + multiple_datasets[key].isel(time_utc=i)[coord_key] + ) From 2dcfbd1612fde32c43a9d60056f77a90c6c5a420 Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Wed, 15 Nov 2023 13:35:45 +0000 Subject: [PATCH 10/29] Update test for NWP --- tests/utils/test_utils.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py index 20632d738..4c5c8819a 100644 --- a/tests/utils/test_utils.py +++ b/tests/utils/test_utils.py @@ -26,14 +26,18 @@ def test_combine_uncombine_from_single_dataset(configuration_filename): assert isinstance(dataset, xr.Dataset) multiple_datasets = uncombine_from_single_dataset(dataset) for key in multiple_datasets.keys(): - for i in range(len(multiple_datasets[key].time_utc)): + if "time_utc" in multiple_datasets[key].coords.keys(): + time_coord = "time_utc" + else: + time_coord = "target_time_utc" + for i in range(len(multiple_datasets[key][time_coord])): # Assert that coordinates are the same assert 
( datasets[key][i].coords.keys() - == multiple_datasets[key].isel(time_utc=i).coords.keys() + == multiple_datasets[key].isel({time_coord: i}).coords.keys() ) # Assert that data for each of the coords is the same for coord_key in datasets[key][i].coords.keys(): assert datasets[key][i][coord_key].equals( - multiple_datasets[key].isel(time_utc=i)[coord_key] + multiple_datasets[key].isel({time_coord: i})[coord_key] ) From 44c5efb9b9674d858907aedd0566e45776ed8970 Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Wed, 15 Nov 2023 14:12:12 +0000 Subject: [PATCH 11/29] Add test for windnet --- ocf_datapipes/utils/utils.py | 2 +- ocf_datapipes/validation/check_for_nans.py | 1 + tests/training/test_windnet.py | 28 ++++++++++++++++++++++ 3 files changed, 30 insertions(+), 1 deletion(-) create mode 100644 tests/training/test_windnet.py diff --git a/ocf_datapipes/utils/utils.py b/ocf_datapipes/utils/utils.py index abbd5275a..b1c5bc25a 100644 --- a/ocf_datapipes/utils/utils.py +++ b/ocf_datapipes/utils/utils.py @@ -336,7 +336,7 @@ def combine_to_single_dataset(dataset_dict: dict[str, xr.Dataset]) -> xr.Dataset Combine multiple datasets into a single dataset Args: - *datasets: Datasets to combine + dataset_dict: Dictionary of xr.Dataset objects to combine Returns: Combined dataset diff --git a/ocf_datapipes/validation/check_for_nans.py b/ocf_datapipes/validation/check_for_nans.py index 421ce2a83..dfea01198 100644 --- a/ocf_datapipes/validation/check_for_nans.py +++ b/ocf_datapipes/validation/check_for_nans.py @@ -25,6 +25,7 @@ def __init__( source_datapipe: Datapipe emitting Xarray Datasets dataset_name: Optional name for dataset to check, if None, checks whole dataset fill_nans: Whether to fill NaNs with 0 or not + fill_value: Value to fill NaNs with """ self.source_datapipe = source_datapipe self.dataset_name = dataset_name diff --git a/tests/training/test_windnet.py b/tests/training/test_windnet.py new file mode 100644 index 000000000..f9919d90a --- /dev/null +++ b/tests/training/test_windnet.py @@ -0,0 +1,28 @@ +from datetime import datetime + +from ocf_datapipes.training.windnet import ( + windnet_datapipe, + windnet_netcdf_datapipe, +) +from ocf_datapipes.utils.utils import combine_to_single_dataset, uncombine_from_single_dataset +import pytest + + +def test_windnet_datapipe(configuration_filename): + start_time = datetime(1900, 1, 1) + end_time = datetime(2050, 1, 1) + dp = windnet_datapipe( + configuration_filename, + start_time=start_time, + end_time=end_time, + ) + datasets = next(iter(dp)) + dataset = combine_to_single_dataset(datasets) + # Need to serialize attributes to strings + dataset.to_netcdf("test.nc", mode="w", engine="h5netcdf", compute=True) + dp = windnet_netcdf_datapipe( + config_filename=configuration_filename, + filenames=["test.nc"], + keys=["gsp", "nwp", "sat", "pv"], + ) + datasets = next(iter(dp)) From f10968cea3f8b40bbaf318e5dad21ac0c4c5a122 Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Wed, 15 Nov 2023 14:19:34 +0000 Subject: [PATCH 12/29] Refactor out common functions --- ocf_datapipes/training/common.py | 572 +++++++++++++++++++++++++++++- ocf_datapipes/training/pvnet.py | 33 -- ocf_datapipes/training/windnet.py | 557 +---------------------------- tests/training/test_pvnet.py | 2 +- 4 files changed, 590 insertions(+), 574 deletions(-) diff --git a/ocf_datapipes/training/common.py b/ocf_datapipes/training/common.py index 7c668f331..5dd3d0f52 100644 --- a/ocf_datapipes/training/common.py +++ b/ocf_datapipes/training/common.py @@ -1,7 +1,11 @@ """Common functionality 
for datapipes""" import logging -from datetime import timedelta +from datetime import timedelta, datetime +from typing import List, Union, Optional, Tuple, Dict +import numpy as np +import xarray as xr +from torch.utils.data import functional_datapipe from torchdata.datapipes.iter import IterDataPipe from ocf_datapipes.config.model import Configuration @@ -13,6 +17,10 @@ OpenSatellite, OpenTopography, ) +from ocf_datapipes.load.gsp.database import OpenGSPFromDatabaseIterDataPipe +from ocf_datapipes.load.pv.database import OpenPVFromPVSitesDBIterDataPipe +from ocf_datapipes.training.pvnet import logger +from ocf_datapipes.utils.consts import NumpyBatch, BatchKey logger = logging.getLogger(__name__) @@ -432,3 +440,565 @@ def create_t0_and_loc_datapipes( location_pipe, t0_datapipe = t0_loc_datapipe.unzip(sequence_length=2) return location_pipe, t0_datapipe + + +def normalize_gsp(x): + """Normalize the GSP data + + Args: + x: Input DataArray + + Returns: + Normalized DataArray + """ + return x / x.effective_capacity_mwp + + +def normalize_pv(x): + """Normalize the PV data + + Args: + x: Input DataArray + + Returns: + Normalized DataArray + """ + return (x / x.nominal_capacity_wp).clip(None, 5) + + +def production_sat_scale(x): + """Scale the production satellite data + + Args: + x: Input DataArray + + Returns: + Scaled DataArray + """ + return x / 1024 + + +def concat_xr_time_utc(gsp_dataarrays: List[xr.DataArray]): + """This function is used to combine the split history and future gsp/pv dataarrays. + + These are split inside the `slice_datapipes_by_time()` function below. + + Splitting them inside that function allows us to apply dropout to the + history GSP/PV whilst leaving the future GSP/PV without NaNs. + + We recombine the history and future with this function to allow us to use the + `MergeNumpyModalities()` datapipe without redefining the BatchKeys. + + The `pvnet` model was also written to use a GSP/PV array which has historical and future + and to split it out. These maintains that assumption. + """ + return xr.concat(gsp_dataarrays, dim="time_utc") + + +def gsp_drop_national(x: Union[xr.DataArray, xr.Dataset]): + """Drop entries for national PV output + + Args: + x: Data source of gsp data + + Returns: + Filtered data source + """ + return x.where(x.gsp_id != 0, drop=True) + + +@functional_datapipe("pvnet_select_pv_by_ml_id") +class PVNetSelectPVbyMLIDIterDataPipe(IterDataPipe): + """Select specific set of PV systems by ML ID.""" + + def __init__(self, source_datapipe: IterDataPipe, ml_ids: np.array): + """Select specific set of PV systems by ML ID. + + Args: + source_datapipe: Datapipe emitting PV xarray data + ml_ids: List-like of ML IDs to select + + Returns: + Filtered data source + """ + self.source_datapipe = source_datapipe + self.ml_ids = ml_ids + + def __iter__(self): + for x in self.source_datapipe: + # Check for missing IDs + ml_ids_not_in_data = ~np.isin(self.ml_ids, x.ml_id) + if ml_ids_not_in_data.any(): + missing_ml_ids = np.array(self.ml_ids)[ml_ids_not_in_data] + logger.warning( + f"The following ML IDs were mising in the PV site-level input data: " + f"{missing_ml_ids}. The values for these IDs will be set to NaN." 
+ ) + + x_filtered = ( + # Many ML-IDs are null, so filter first + x.where(~x.ml_id.isnull(), drop=True) + # Swap dimensions so we can select by ml_id coordinate + .swap_dims({"pv_system_id": "ml_id"}) + # Select IDs - missing IDs are given NaN values + .reindex(ml_id=self.ml_ids) + # Swap back dimensions + .swap_dims({"ml_id": "pv_system_id"}) + ) + yield x_filtered + + +def fill_nans_in_pv(x: Union[xr.DataArray, xr.Dataset]): + """Fill NaNs in PV data with the value -1 + + Args: + x: Input DataArray + + Returns: + Normalized DataArray + """ + return x.fillna(-1) + + +def fill_nans_in_arrays(batch: NumpyBatch) -> NumpyBatch: + """Fills all NaN values in each np.ndarray in the batch dictionary with zeros. + + Operation is performed in-place on the batch. + """ + logger.info("Filling Nans with zeros") + for k, v in batch.items(): + if isinstance(v, np.ndarray): + np.nan_to_num(v, copy=False, nan=0.0) + return batch + + +class AddZeroedSatelliteData: + """A callable class used to add zeroed-out satellite data to batches of data. + + This is useful + to speed up batch loading if pre-training the output part of the network without satellite + inputs. + """ + + def __init__(self, configuration: Configuration, is_hrv: bool = False): + """A callable class used to add zeroed-out satellite data to batches of data. + + Args: + configuration: Configuration object + is_hrv: If False, non-HRV data is added by called function, else HRV. + """ + + self.configuration = configuration + self.is_hrv = is_hrv + + def __call__(self, batch: NumpyBatch) -> NumpyBatch: + """Add zeroed-out satellite data to batch with shape accoriding to supplied configuration. + + Batch is modified in-place and returned. + + Args: + batch: Numpy batch of input data. + """ + + variable = "hrvsatellite" if self.is_hrv else "satellite" + + satellite_config = getattr(self.configuration.input_data, variable) + + n_channels = len(getattr(satellite_config, f"{variable}_channels")) + height = getattr(satellite_config, f"{variable}_image_size_pixels_height") + width = getattr(satellite_config, f"{variable}_image_size_pixels_width") + + sequence_len = satellite_config.history_minutes // 5 + 1 - 3 + + batch[getattr(BatchKey, f"{variable}_actual")] = np.zeros( + (sequence_len, n_channels, height, width) + ) + + return batch + + +class AddZeroedNWPData: + """A callable class used to add zeroed-out NWP data to batches of data. + + This is useful to speed up batch loading if pre-training the output part of the network without + NWP inputs. + """ + + def __init__(self, configuration: Configuration): + """A callable class used to add zeroed-out NWP data to batches of data. + + Args: + configuration: Configuration object + """ + self.configuration = configuration + + def __call__(self, batch: NumpyBatch) -> NumpyBatch: + """Add zeroed-out NWP data to batch with shape accoriding to supplied configuration. + + Batch is modified in-place and returned. + + Args: + batch: Numpy batch of input data. 
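# A small sketch of the zero-filled placeholder produced by AddZeroedSatelliteData above
# (AddZeroedNWPData below follows the same pattern with hourly steps). The config values here
# (60 min of 5-minutely history, 11 channels, 24x24 pixels) are illustrative only; the "- 3"
# term simply mirrors the sequence-length formula used in the class.
import numpy as np

history_minutes, n_channels, height, width = 60, 11, 24, 24
sequence_len = history_minutes // 5 + 1 - 3
zeroed_satellite = np.zeros((sequence_len, n_channels, height, width))
assert zeroed_satellite.shape == (10, 11, 24, 24)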
+ """ + + config = self.configuration.input_data.nwp + + n_channels = len(config.nwp_channels) + height = config.nwp_image_size_pixels_height + width = config.nwp_image_size_pixels_width + + sequence_len = config.history_minutes // 60 + config.forecast_minutes // 60 + 1 + + batch[BatchKey.nwp] = np.zeros((sequence_len, n_channels, height, width)) + + return batch + + +class DatapipeKeyForker: + """ "Internal helper function to track forking of a datapipe.""" + + def __init__(self, keys: List, datapipe: IterDataPipe): + """Internal helper function to track forking of a datapipe. + + As forks are returned, this object tracks the keys left and returns the final copy of the + datapipe when the last key is requested. This makes multiple forking easier and ensures + closure. + + Args: + keys: List of keys for which datapipe duplication is required. + datapipe: Datapipe which will be forked for each ket + """ + self.keys_left = keys + self.datapipe = datapipe + + def __call__(self, key): + """ "Returns a fork of `self.datapipe` and tracks a the keys left to ensure closure. + + Args: + key: key to remove from `self.keys_left`. If `key` is None then an extra copy is made + without affecting `self.keys_left`. + """ + if len(self.keys_left) == 0: + raise ValueError(f"No keys left when requested key : {key}") + if key is not None: + self.keys_left.remove(key) + if len(self.keys_left) > 0: + self.datapipe, return_datapipe = self.datapipe.fork(2, buffer_size=5) + else: + return_datapipe = self.datapipe + return return_datapipe + + def close(self): + """Asserts that the keys have all been used.""" + assert len(self.keys_left) == 0 + + +def _get_datapipes_dict( + config_filename: str, + block_sat: bool, + block_nwp: bool, + production: bool = False, +): + # Load datasets + datapipes_dict = open_and_return_datapipes( + configuration_filename=config_filename, + use_gsp=(not production), + use_pv=(not production), + use_sat=not block_sat, # Only loaded if we aren't replacing them with zeros + use_hrv=False, + use_nwp=not block_nwp, # Only loaded if we aren't replacing them with zeros + use_topo=False, + production=production, + ) + + config: Configuration = datapipes_dict["config"] + + if production: + datapipes_dict["gsp"] = OpenGSPFromDatabase().add_t0_idx_and_sample_period_duration( + sample_period_duration=timedelta(minutes=30), + history_duration=timedelta(minutes=config.input_data.gsp.history_minutes), + ) + if "sat" in datapipes_dict: + datapipes_dict["sat"] = datapipes_dict["sat"].map(production_sat_scale) + if "pv" in datapipes_dict: + datapipes_dict["pv"] = OpenPVFromPVSitesDB(config.input_data.pv.history_minutes) + + if "pv" in datapipes_dict and config.input_data.pv.pv_ml_ids != []: + datapipes_dict["pv"] = datapipes_dict["pv"].pvnet_select_pv_by_ml_id( + config.input_data.pv.pv_ml_ids + ) + + return datapipes_dict + + +def construct_loctime_pipelines( + config_filename: str, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, + block_sat: bool = False, + block_nwp: bool = False, +) -> Tuple[IterDataPipe, IterDataPipe]: + """Construct location and time pipelines for the input data config file. + + Args: + config_filename: Path to config file. + start_time: Minimum time for time datapipe. + end_time: Maximum time for time datapipe. + block_sat: Whether to load zeroes for satellite data. + block_nwp: Whether to load zeroes for NWP data. 
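# A sketch of the forking pattern implemented by DatapipeKeyForker above: one shared t0
# datapipe is forked once per consumer key, and the final consumer receives the remaining
# pipe itself so no fork is left unused. The keys and values here are illustrative only.
from torchdata.datapipes.iter import IterableWrapper

t0_datapipe = IterableWrapper([0, 1, 2])
keys_left = ["nwp", "sat", "gsp"]
consumers = {}
for key in list(keys_left):
    keys_left.remove(key)
    if keys_left:
        t0_datapipe, consumers[key] = t0_datapipe.fork(2, buffer_size=5)
    else:
        consumers[key] = t0_datapipe  # last key takes the original pipe directly
assert not keys_left and set(consumers) == {"nwp", "sat", "gsp"}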
+ """ + + datapipes_dict = _get_datapipes_dict( + config_filename, + block_sat, + block_nwp, + ) + + # Pull out config file + config = datapipes_dict.pop("config") + + # We sample time and space of other data using GSP time and space coordinates, so filter GSP + # data first amd this is carried through + datapipes_dict["gsp"] = datapipes_dict["gsp"].map(gsp_drop_national) + if (start_time is not None) or (end_time is not None): + datapipes_dict["gsp"] = datapipes_dict["gsp"].select_train_test_time(start_time, end_time) + + # Get overlapping time periods + location_pipe, t0_datapipe = create_t0_and_loc_datapipes( + datapipes_dict, + configuration=config, + key_for_t0="gsp", + shuffle=True, + nwp_max_dropout_minutes=180, + ) + + return location_pipe, t0_datapipe + + +def minutes(num_mins: int): + """Timedelta of a number of minutes. + + Args: + num_mins: Minutes timedelta. + """ + return timedelta(minutes=num_mins) + + +def slice_datapipes_by_time( + datapipes_dict: Dict, + t0_datapipe: IterDataPipe, + configuration: Configuration, + production: bool = False, +) -> None: + """ + Modifies a dictionary of datapipes in-place to yield samples for given times t0. + + The NWP data* will be at least 90 minutes stale (i.e. as if it takes 90 minutes for the foreast + to become available). + + The satellite data* is shaped so that the most recent can be 15 minutes before t0. However, 50% + of the time dropout is applied so that the most recent field is between 45 and 20 minutes before + t0. When dropped out like this, the values after this selected dropout time are set to NaN. + + The HRV data* is similar to the satellite data and if both are included they drop out + simulataneously. + + The GSP data is split into "gsp" and "gsp_future" keys. 10% of the time the gsp value for time + t0, which occurs under the "gsp" key, is set to NaN + + The PV data* is also split it "pv" and "pv_future" keys. + + * if included + + n.b. PV and HRV are included in this function, but not yet in the rest of the pvnet pipeline. + This is mostly for demonstratio purposes of how the dropout might be applied. + + Args: + datapipes_dict: Dictionary of used datapipes and t0 ones + t0_datapipe: Datapipe which yields t0 times for sample + configuration: Configuration object. + production: Whether constucting pipeline for production inference. No dropout is used if + True. + + """ + + conf_in = configuration.input_data + + # Use DatapipeKeyForker to avoid forking t0_datapipe too many times, or leaving any forks unused + fork_keys = {k for k in datapipes_dict.keys() if k not in ["topo"]} + get_t0_datapipe = DatapipeKeyForker(fork_keys, t0_datapipe) + + sat_and_hrv_dropout_kwargs = dict( + # Satellite is either 30 minutes or 60 minutes delayed in production. 
Match during training + dropout_timedeltas=[minutes(-60), minutes(-30)], + dropout_frac=0 if production else 1.0, + ) + + sat_delay = minutes(-configuration.input_data.satellite.live_delay_minutes) + + if "nwp" in datapipes_dict: + datapipes_dict["nwp"] = datapipes_dict["nwp"].convert_to_nwp_target_time_with_dropout( + t0_datapipe=get_t0_datapipe("nwp"), + sample_period_duration=minutes(60), + history_duration=minutes(conf_in.nwp.history_minutes), + forecast_duration=minutes(conf_in.nwp.forecast_minutes), + # The NWP forecast will always be at least 180 minutes stale + dropout_timedeltas=[minutes(-180)], + dropout_frac=0 if production else 1.0, + ) + + if "sat" in datapipes_dict: + # Take time slices of sat data + datapipes_dict["sat"] = datapipes_dict["sat"].select_time_slice( + t0_datapipe=get_t0_datapipe(None), + sample_period_duration=minutes(5), + interval_start=minutes(-conf_in.satellite.history_minutes), + interval_end=sat_delay, + fill_selection=production, + max_steps_gap=2, + ) + + # Generate randomly sampled dropout times + sat_dropout_time_datapipe = get_t0_datapipe("sat").select_dropout_time( + **sat_and_hrv_dropout_kwargs + ) + + if "hrv" in datapipes_dict: + # Make dropout-time copy for hrv if included in data. + # HRV and non-HRV will dropout simultaneously. + sat_dropout_time_datapipe, hrv_dropout_time_datapipe = sat_dropout_time_datapipe.fork( + 2, buffer_size=5 + ) + + # Apply the dropout + datapipes_dict["sat"] = datapipes_dict["sat"].apply_dropout_time( + dropout_time_datapipe=sat_dropout_time_datapipe, + ) + + if "hrv" in datapipes_dict: + if "sat" not in datapipes_dict: + # Generate randomly sampled dropout times + # This is shared with sat if sat included + hrv_dropout_time_datapipe = get_t0_datapipe(None).select_dropout_time( + **sat_and_hrv_dropout_kwargs + ) + + datapipes_dict["hrv"] = datapipes_dict["hrv"].select_time_slice( + t0_datapipe=get_t0_datapipe("hrv"), + sample_period_duration=minutes(5), + interval_start=minutes(-conf_in.hrvsatellite.history_minutes), + interval_end=sat_delay, + fill_selection=production, + max_steps_gap=2, + ) + + # Apply the dropout + datapipes_dict["hrv"] = datapipes_dict["hrv"].apply_dropout_time( + dropout_time_datapipe=hrv_dropout_time_datapipe, + ) + + if "pv" in datapipes_dict: + datapipes_dict["pv"], dp = datapipes_dict["pv"].fork(2, buffer_size=5) + + datapipes_dict["pv_future"] = dp.select_time_slice( + t0_datapipe=get_t0_datapipe(None), + sample_period_duration=minutes(5), + interval_start=minutes(5), + interval_end=minutes(conf_in.pv.forecast_minutes), + fill_selection=production, + ) + + datapipes_dict["pv"] = datapipes_dict["pv"].select_time_slice( + t0_datapipe=get_t0_datapipe(None), + sample_period_duration=minutes(5), + interval_start=minutes(-conf_in.pv.history_minutes), + interval_end=minutes(0), + fill_selection=production, + ) + + # Dropout on the PV, but not the future PV + pv_dropout_time_datapipe = get_t0_datapipe("pv").select_dropout_time( + # All PV data could be delayed by up to 30 minutes + # (this does not stem from production - just setting for now) + dropout_timedeltas=[minutes(m) for m in range(-30, 0, 5)], + dropout_frac=0.1 if production else 1, + ) + + datapipes_dict["pv"] = datapipes_dict["pv"].apply_dropout_time( + dropout_time_datapipe=pv_dropout_time_datapipe, + ) + + # Apply extra PV dropout using different delays per system and droping out entire PV systems + # independently + if not production: + datapipes_dict["pv"].apply_pv_dropout( + system_dropout_fractions=np.linspace(0, 0.2, 100), + 
system_dropout_timedeltas=[minutes(m) for m in [-15, -10, -5, 0]], + ) + + if "gsp" in datapipes_dict: + datapipes_dict["gsp"], dp = datapipes_dict["gsp"].fork(2, buffer_size=5) + + datapipes_dict["gsp_future"] = dp.select_time_slice( + t0_datapipe=get_t0_datapipe(None), + sample_period_duration=minutes(30), + interval_start=minutes(30), + interval_end=minutes(conf_in.gsp.forecast_minutes), + fill_selection=production, + ) + + datapipes_dict["gsp"] = datapipes_dict["gsp"].select_time_slice( + t0_datapipe=get_t0_datapipe(None), + sample_period_duration=minutes(30), + interval_start=-minutes(conf_in.gsp.history_minutes), + interval_end=minutes(0), + fill_selection=production, + ) + + # Dropout on the GSP, but not the future GSP + gsp_dropout_time_datapipe = get_t0_datapipe("gsp").select_dropout_time( + # GSP data for time t0 may be missing. Only have value for t0-30mins + dropout_timedeltas=[minutes(-30)], + dropout_frac=0 if production else 0.1, + ) + + datapipes_dict["gsp"] = datapipes_dict["gsp"].apply_dropout_time( + dropout_time_datapipe=gsp_dropout_time_datapipe, + ) + + get_t0_datapipe.close() + + return + + +def check_nans_in_satellite_data(batch: NumpyBatch) -> NumpyBatch: + """ + Check if there are any Nans values in the satellite data. + """ + if np.any(np.isnan(batch[BatchKey.satellite_actual])): + logger.error("Found nans values in satellite data") + + logger.error(batch[BatchKey.satellite_actual].shape) + + # loop over time and channels + for dim in [0, 1]: + for t in range(batch[BatchKey.satellite_actual].shape[dim]): + if dim == 0: + sate_data_one_step = batch[BatchKey.satellite_actual][t] + else: + sate_data_one_step = batch[BatchKey.satellite_actual][:, t] + nans = np.isnan(sate_data_one_step) + + if np.any(nans): + percent_nans = np.sum(nans) / np.prod(sate_data_one_step.shape) * 100 + + logger.error( + f"Found nans values in satellite data at index {t} ({dim=}). " + f"{percent_nans}% of values are nans" + ) + else: + logger.error(f"Found no nans values in satellite data at index {t} {dim=}") + + raise ValueError("Found nans values in satellite data") + + return batch diff --git a/ocf_datapipes/training/pvnet.py b/ocf_datapipes/training/pvnet.py index 7f8788481..537c1e039 100644 --- a/ocf_datapipes/training/pvnet.py +++ b/ocf_datapipes/training/pvnet.py @@ -725,36 +725,3 @@ def pvnet_datapipe( ) return datapipe - - -def check_nans_in_satellite_data(batch: NumpyBatch) -> NumpyBatch: - """ - Check if there are any Nans values in the satellite data. - """ - if np.any(np.isnan(batch[BatchKey.satellite_actual])): - logger.error("Found nans values in satellite data") - - logger.error(batch[BatchKey.satellite_actual].shape) - - # loop over time and channels - for dim in [0, 1]: - for t in range(batch[BatchKey.satellite_actual].shape[dim]): - if dim == 0: - sate_data_one_step = batch[BatchKey.satellite_actual][t] - else: - sate_data_one_step = batch[BatchKey.satellite_actual][:, t] - nans = np.isnan(sate_data_one_step) - - if np.any(nans): - percent_nans = np.sum(nans) / np.prod(sate_data_one_step.shape) * 100 - - logger.error( - f"Found nans values in satellite data at index {t} ({dim=}). 
" - f"{percent_nans}% of values are nans" - ) - else: - logger.error(f"Found no nans values in satellite data at index {t} {dim=}") - - raise ValueError("Found nans values in satellite data") - - return batch diff --git a/ocf_datapipes/training/windnet.py b/ocf_datapipes/training/windnet.py index d4b2cd343..c0e467139 100644 --- a/ocf_datapipes/training/windnet.py +++ b/ocf_datapipes/training/windnet.py @@ -3,7 +3,6 @@ from datetime import datetime, timedelta from typing import Dict, List, Optional, Tuple, Union -import numpy as np import xarray as xr xr.set_options(keep_attrs=True) @@ -14,12 +13,6 @@ from ocf_datapipes.config.model import Configuration from ocf_datapipes.load import ( OpenConfiguration, - OpenGSPFromDatabase, - OpenPVFromPVSitesDB, -) -from ocf_datapipes.training.common import ( - create_t0_and_loc_datapipes, - open_and_return_datapipes, ) from ocf_datapipes.utils.consts import ( NEW_NWP_MEAN, @@ -30,86 +23,22 @@ NumpyBatch, ) from ocf_datapipes.utils.utils import combine_to_single_dataset, uncombine_from_single_dataset +from ocf_datapipes.training.common import ( + normalize_gsp, + normalize_pv, + concat_xr_time_utc, + fill_nans_in_pv, + fill_nans_in_arrays, + AddZeroedSatelliteData, + AddZeroedNWPData, + _get_datapipes_dict, + construct_loctime_pipelines, + slice_datapipes_by_time, + check_nans_in_satellite_data, +) xr.set_options(keep_attrs=True) -logger = logging.getLogger("pvnet_datapipe") - - -def normalize_gsp(x): - """Normalize the GSP data - - Args: - x: Input DataArray - - Returns: - Normalized DataArray - """ - return x / x.effective_capacity_mwp - - -def normalize_pv(x): - """Normalize the PV data - - Args: - x: Input DataArray - - Returns: - Normalized DataArray - """ - return (x / x.nominal_capacity_wp).clip(None, 5) - - -def production_sat_scale(x): - """Scale the production satellite data - - Args: - x: Input DataArray - - Returns: - Scaled DataArray - """ - return x / 1024 - - -def concat_xr_time_utc(gsp_dataarrays: List[xr.DataArray]): - """This function is used to combine the split history and future gsp/pv dataarrays. - - These are split inside the `slice_datapipes_by_time()` function below. - - Splitting them inside that function allows us to apply dropout to the - history GSP/PV whilst leaving the future GSP/PV without NaNs. - - We recombine the history and future with this function to allow us to use the - `MergeNumpyModalities()` datapipe without redefining the BatchKeys. - - The `pvnet` model was also written to use a GSP/PV array which has historical and future - and to split it out. These maintains that assumption. - """ - return xr.concat(gsp_dataarrays, dim="time_utc") - - -def gsp_drop_national(x: Union[xr.DataArray, xr.Dataset]): - """Drop entries for national PV output - - Args: - x: Data source of gsp data - - Returns: - Filtered data source - """ - return x.where(x.gsp_id != 0, drop=True) - - -def fill_nans_in_pv(x: Union[xr.DataArray, xr.Dataset]): - """Fill NaNs in PV data with the value -1 - - Args: - x: Input DataArray - - Returns: - Normalized DataArray - """ - return x.fillna(-1) +logger = logging.getLogger("windnet_datapipe") def scale_wind_speed_to_power(x: Union[xr.DataArray, xr.Dataset]): @@ -131,139 +60,6 @@ def scale_wind_speed_to_power(x: Union[xr.DataArray, xr.Dataset]): return x -def fill_nans_in_arrays(batch: NumpyBatch) -> NumpyBatch: - """Fills all NaN values in each np.ndarray in the batch dictionary with zeros. - - Operation is performed in-place on the batch. 
- """ - logger.info("Filling Nans with zeros") - for k, v in batch.items(): - if isinstance(v, np.ndarray): - np.nan_to_num(v, copy=False, nan=0.0) - return batch - - -class AddZeroedSatelliteData: - """A callable class used to add zeroed-out satellite data to batches of data. - - This is useful - to speed up batch loading if pre-training the output part of the network without satellite - inputs. - """ - - def __init__(self, configuration: Configuration, is_hrv: bool = False): - """A callable class used to add zeroed-out satellite data to batches of data. - - Args: - configuration: Configuration object - is_hrv: If False, non-HRV data is added by called function, else HRV. - """ - - self.configuration = configuration - self.is_hrv = is_hrv - - def __call__(self, batch: NumpyBatch) -> NumpyBatch: - """Add zeroed-out satellite data to batch with shape accoriding to supplied configuration. - - Batch is modified in-place and returned. - - Args: - batch: Numpy batch of input data. - """ - - variable = "hrvsatellite" if self.is_hrv else "satellite" - - satellite_config = getattr(self.configuration.input_data, variable) - - n_channels = len(getattr(satellite_config, f"{variable}_channels")) - height = getattr(satellite_config, f"{variable}_image_size_pixels_height") - width = getattr(satellite_config, f"{variable}_image_size_pixels_width") - - sequence_len = satellite_config.history_minutes // 5 + 1 - 3 - - batch[getattr(BatchKey, f"{variable}_actual")] = np.zeros( - (sequence_len, n_channels, height, width) - ) - - return batch - - -class AddZeroedNWPData: - """A callable class used to add zeroed-out NWP data to batches of data. - - This is useful to speed up batch loading if pre-training the output part of the network without - NWP inputs. - """ - - def __init__(self, configuration: Configuration): - """A callable class used to add zeroed-out NWP data to batches of data. - - Args: - configuration: Configuration object - """ - self.configuration = configuration - - def __call__(self, batch: NumpyBatch) -> NumpyBatch: - """Add zeroed-out NWP data to batch with shape accoriding to supplied configuration. - - Batch is modified in-place and returned. - - Args: - batch: Numpy batch of input data. - """ - - config = self.configuration.input_data.nwp - - n_channels = len(config.nwp_channels) - height = config.nwp_image_size_pixels_height - width = config.nwp_image_size_pixels_width - - sequence_len = config.history_minutes // 60 + config.forecast_minutes // 60 + 1 - - batch[BatchKey.nwp] = np.zeros((sequence_len, n_channels, height, width)) - - return batch - - -class DatapipeKeyForker: - """ "Internal helper function to track forking of a datapipe.""" - - def __init__(self, keys: List, datapipe: IterDataPipe): - """Internal helper function to track forking of a datapipe. - - As forks are returned, this object tracks the keys left and returns the final copy of the - datapipe when the last key is requested. This makes multiple forking easier and ensures - closure. - - Args: - keys: List of keys for which datapipe duplication is required. - datapipe: Datapipe which will be forked for each ket - """ - self.keys_left = keys - self.datapipe = datapipe - - def __call__(self, key): - """ "Returns a fork of `self.datapipe` and tracks a the keys left to ensure closure. - - Args: - key: key to remove from `self.keys_left`. If `key` is None then an extra copy is made - without affecting `self.keys_left`. 
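# A sketch of the knots-to-rough-power scaling kept in scale_wind_speed_to_power above:
# sensor wind speeds in knots are converted to m/s and then roughly doubled as a crude
# stand-in for generated power. The numeric value below is illustrative only.
speed_knots = 10.0
speed_m_per_s = speed_knots * 0.514444  # knots -> m/s
rough_power = speed_m_per_s * 2         # rough doubling used as the power proxy
assert abs(rough_power - 10.28888) < 1e-6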
- """ - if len(self.keys_left) == 0: - raise ValueError(f"No keys left when requested key : {key}") - if key is not None: - self.keys_left.remove(key) - if len(self.keys_left) > 0: - self.datapipe, return_datapipe = self.datapipe.fork(2, buffer_size=5) - else: - return_datapipe = self.datapipe - return return_datapipe - - def close(self): - """Asserts that the keys have all been used.""" - assert len(self.keys_left) == 0 - - @functional_datapipe("dict_datasets") class DictDatasetIterDataPipe(IterDataPipe): """ """ @@ -379,84 +175,6 @@ def __iter__(self): yield next(iter(combined_datapipe)) -def _get_datapipes_dict( - config_filename: str, - block_sat: bool, - block_nwp: bool, - production: bool = False, -): - # Load datasets - datapipes_dict = open_and_return_datapipes( - configuration_filename=config_filename, - use_gsp=(not production), - use_pv=(not production), - use_sat=not block_sat, # Only loaded if we aren't replacing them with zeros - use_hrv=False, - use_nwp=not block_nwp, # Only loaded if we aren't replacing them with zeros - use_topo=False, - production=production, - ) - - config: Configuration = datapipes_dict["config"] - - if production: - datapipes_dict["gsp"] = OpenGSPFromDatabase().add_t0_idx_and_sample_period_duration( - sample_period_duration=timedelta(minutes=30), - history_duration=timedelta(minutes=config.input_data.gsp.history_minutes), - ) - if "sat" in datapipes_dict: - datapipes_dict["sat"] = datapipes_dict["sat"].map(production_sat_scale) - if "pv" in datapipes_dict: - datapipes_dict["pv"] = OpenPVFromPVSitesDB(config.input_data.pv.history_minutes) - - return datapipes_dict - - -def construct_loctime_pipelines( - config_filename: str, - start_time: Optional[datetime] = None, - end_time: Optional[datetime] = None, - block_sat: bool = False, - block_nwp: bool = False, -) -> Tuple[IterDataPipe, IterDataPipe]: - """Construct location and time pipelines for the input data config file. - - Args: - config_filename: Path to config file. - start_time: Minimum time for time datapipe. - end_time: Maximum time for time datapipe. - block_sat: Whether to load zeroes for satellite data. - block_nwp: Whether to load zeroes for NWP data. - """ - - datapipes_dict = _get_datapipes_dict( - config_filename, - block_sat, - block_nwp, - ) - - # Pull out config file - config = datapipes_dict.pop("config") - - # We sample time and space of other data using GSP time and space coordinates, so filter GSP - # data first amd this is carried through - # Map from wind speed to m/s here - datapipes_dict["gsp"] = datapipes_dict["gsp"] - if (start_time is not None) or (end_time is not None): - datapipes_dict["gsp"] = datapipes_dict["gsp"].select_train_test_time(start_time, end_time) - - # Get overlapping time periods - location_pipe, t0_datapipe = create_t0_and_loc_datapipes( - datapipes_dict, - configuration=config, - key_for_t0="gsp", - shuffle=True, - nwp_max_dropout_minutes=180, - ) - - return location_pipe, t0_datapipe - - def minutes(num_mins: int): """Timedelta of a number of minutes. @@ -466,193 +184,6 @@ def minutes(num_mins: int): return timedelta(minutes=num_mins) -def slice_datapipes_by_time( - datapipes_dict: Dict, - t0_datapipe: IterDataPipe, - configuration: Configuration, - production: bool = False, -) -> None: - """ - Modifies a dictionary of datapipes in-place to yield samples for given times t0. - - The NWP data* will be at least 90 minutes stale (i.e. as if it takes 90 minutes for the foreast - to become available). 
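# A sketch of the NWP staleness idea: for a sample at time t0, only forecast init times at
# least some minimum age are usable. The 180-minute value matches the dropout_timedeltas used
# in slice_datapipes_by_time; the init times and t0 below are illustrative, and the real
# pipeline applies this via convert_to_nwp_target_time_with_dropout rather than a list filter.
from datetime import datetime, timedelta

init_times = [datetime(2023, 1, 1, h) for h in range(0, 24, 3)]  # 3-hourly init times
t0 = datetime(2023, 1, 1, 13, 30)
usable = [t for t in init_times if t <= t0 - timedelta(minutes=180)]
assert max(usable) == datetime(2023, 1, 1, 9)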
- - The satellite data* is shaped so that the most recent can be 15 minutes before t0. However, 50% - of the time dropout is applied so that the most recent field is between 45 and 20 minutes before - t0. When dropped out like this, the values after this selected dropout time are set to NaN. - - The HRV data* is similar to the satellite data and if both are included they drop out - simulataneously. - - The GSP data is split into "gsp" and "gsp_future" keys. 10% of the time the gsp value for time - t0, which occurs under the "gsp" key, is set to NaN - - The PV data* is also split it "pv" and "pv_future" keys. - - * if included - - n.b. PV and HRV are included in this function, but not yet in the rest of the pvnet pipeline. - This is mostly for demonstratio purposes of how the dropout might be applied. - - Args: - datapipes_dict: Dictionary of used datapipes and t0 ones - t0_datapipe: Datapipe which yields t0 times for sample - configuration: Configuration object. - production: Whether constucting pipeline for production inference. No dropout is used if - True. - - """ - - conf_in = configuration.input_data - - # Use DatapipeKeyForker to avoid forking t0_datapipe too many times, or leaving any forks unused - fork_keys = {k for k in datapipes_dict.keys() if k not in ["topo"]} - get_t0_datapipe = DatapipeKeyForker(fork_keys, t0_datapipe) - - sat_and_hrv_dropout_kwargs = dict( - # Satellite is either 30 minutes or 60 minutes delayed in production. Match during training - dropout_timedeltas=[minutes(-60), minutes(-30)], - dropout_frac=0 if production else 1.0, - ) - - sat_delay = minutes(-configuration.input_data.satellite.live_delay_minutes) - - if "nwp" in datapipes_dict: - datapipes_dict["nwp"] = datapipes_dict["nwp"].convert_to_nwp_target_time_with_dropout( - t0_datapipe=get_t0_datapipe("nwp"), - sample_period_duration=minutes(60), - history_duration=minutes(conf_in.nwp.history_minutes), - forecast_duration=minutes(conf_in.nwp.forecast_minutes), - # The NWP forecast will always be at least 180 minutes stale - dropout_timedeltas=[minutes(-180)], - dropout_frac=0 if production else 1.0, - ) - - if "sat" in datapipes_dict: - # Take time slices of sat data - datapipes_dict["sat"] = datapipes_dict["sat"].select_time_slice( - t0_datapipe=get_t0_datapipe(None), - sample_period_duration=minutes(5), - interval_start=minutes(-conf_in.satellite.history_minutes), - interval_end=sat_delay, - fill_selection=production, - max_steps_gap=2, - ) - - # Generate randomly sampled dropout times - sat_dropout_time_datapipe = get_t0_datapipe("sat").select_dropout_time( - **sat_and_hrv_dropout_kwargs - ) - - if "hrv" in datapipes_dict: - # Make dropout-time copy for hrv if included in data. - # HRV and non-HRV will dropout simultaneously. 
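# A sketch of the satellite dropout applied here: a delay of 30 or 60 minutes before t0 is
# sampled, and any timesteps newer than that cut-off are set to NaN. The timestamps and 5-min
# spacing are illustrative; the real pipeline does this via select_dropout_time/apply_dropout_time.
import numpy as np
import pandas as pd

rng = np.random.default_rng(42)
t0 = pd.Timestamp("2023-01-01 12:00")
times = pd.date_range(t0 - pd.Timedelta("60min"), t0, freq="5min")
data = np.ones(len(times))
cutoff = t0 + pd.Timedelta(minutes=int(rng.choice([-60, -30])))
data[times > cutoff] = np.nan  # everything newer than the sampled dropout time becomes NaN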
- sat_dropout_time_datapipe, hrv_dropout_time_datapipe = sat_dropout_time_datapipe.fork( - 2, buffer_size=5 - ) - - # Apply the dropout - datapipes_dict["sat"] = datapipes_dict["sat"].apply_dropout_time( - dropout_time_datapipe=sat_dropout_time_datapipe, - ) - - if "hrv" in datapipes_dict: - if "sat" not in datapipes_dict: - # Generate randomly sampled dropout times - # This is shared with sat if sat included - hrv_dropout_time_datapipe = get_t0_datapipe(None).select_dropout_time( - **sat_and_hrv_dropout_kwargs - ) - - datapipes_dict["hrv"] = datapipes_dict["hrv"].select_time_slice( - t0_datapipe=get_t0_datapipe("hrv"), - sample_period_duration=minutes(5), - interval_start=minutes(-conf_in.hrvsatellite.history_minutes), - interval_end=sat_delay, - fill_selection=production, - max_steps_gap=2, - ) - - # Apply the dropout - datapipes_dict["hrv"] = datapipes_dict["hrv"].apply_dropout_time( - dropout_time_datapipe=hrv_dropout_time_datapipe, - ) - - if "pv" in datapipes_dict: - datapipes_dict["pv"], dp = datapipes_dict["pv"].fork(2, buffer_size=5) - - datapipes_dict["pv_future"] = dp.select_time_slice( - t0_datapipe=get_t0_datapipe(None), - sample_period_duration=minutes(5), - interval_start=minutes(5), - interval_end=minutes(conf_in.pv.forecast_minutes), - fill_selection=production, - ) - - datapipes_dict["pv"] = datapipes_dict["pv"].select_time_slice( - t0_datapipe=get_t0_datapipe(None), - sample_period_duration=minutes(5), - interval_start=minutes(-conf_in.pv.history_minutes), - interval_end=minutes(0), - fill_selection=production, - ) - - # Dropout on the PV, but not the future PV - pv_dropout_time_datapipe = get_t0_datapipe("pv").select_dropout_time( - # All PV data could be delayed by up to 30 minutes - # (this does not stem from production - just setting for now) - dropout_timedeltas=[minutes(m) for m in range(-30, 0, 5)], - dropout_frac=0.1 if production else 1, - ) - - datapipes_dict["pv"] = datapipes_dict["pv"].apply_dropout_time( - dropout_time_datapipe=pv_dropout_time_datapipe, - ) - - # Apply extra PV dropout using different delays per system and droping out entire PV systems - # independently - if not production: - datapipes_dict["pv"].apply_pv_dropout( - system_dropout_fractions=np.linspace(0, 0.2, 100), - system_dropout_timedeltas=[minutes(m) for m in [-15, -10, -5, 0]], - ) - - if "gsp" in datapipes_dict: - datapipes_dict["gsp"], dp = datapipes_dict["gsp"].fork(2, buffer_size=5) - - datapipes_dict["gsp_future"] = dp.select_time_slice( - t0_datapipe=get_t0_datapipe(None), - sample_period_duration=minutes(30), - interval_start=minutes(30), - interval_end=minutes(conf_in.gsp.forecast_minutes), - fill_selection=production, - ) - - datapipes_dict["gsp"] = datapipes_dict["gsp"].select_time_slice( - t0_datapipe=get_t0_datapipe(None), - sample_period_duration=minutes(30), - interval_start=-minutes(conf_in.gsp.history_minutes), - interval_end=minutes(0), - fill_selection=production, - ) - - # Dropout on the GSP, but not the future GSP - gsp_dropout_time_datapipe = get_t0_datapipe("gsp").select_dropout_time( - # GSP data for time t0 may be missing. 
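# A sketch of the t0 GSP dropout described in the comment above: with 10% probability
# (dropout_frac=0.1 outside production) the reading at t0 is blanked while the earlier
# history is kept. The toy values below are illustrative only.
import numpy as np

rng = np.random.default_rng(0)
gsp_history = np.array([0.4, 0.5, 0.6])  # ..., t0-60min, t0-30min, t0
if rng.random() < 0.1:
    gsp_history[-1] = np.nan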
Only have value for t0-30mins - dropout_timedeltas=[minutes(-30)], - dropout_frac=0 if production else 0.1, - ) - - datapipes_dict["gsp"] = datapipes_dict["gsp"].apply_dropout_time( - dropout_time_datapipe=gsp_dropout_time_datapipe, - ) - - get_t0_datapipe.close() - - return - - def construct_sliced_data_pipeline( config_filename: str, location_pipe: IterDataPipe, @@ -672,7 +203,6 @@ def construct_sliced_data_pipeline( block_sat: Whether to load zeroes for satellite data. block_nwp: Whether to load zeroes for NWP data. production: Whether constucting pipeline for production inference. - check_satellite_no_zeros: Whether to check that satellite data has no zeros. """ assert not (production and (block_sat or block_nwp)) @@ -819,6 +349,9 @@ def windnet_datapipe( def split_dataset_dict_dp(element): """ Split the dictionary of datapipes into individual datapipes + + Args: + element: Dictionary of datapipes """ return {k: IterableWrapper([v]) for k, v in element.items() if k != "config"} @@ -836,6 +369,7 @@ def windnet_netcdf_datapipe( Args: config_filename: Path to config file. keys: List of keys to extract from the single NetCDF files + filenames: List of NetCDF files to load block_sat: Whether to load zeroes for satellite data. block_nwp: Whether to load zeroes for NWP data. @@ -856,58 +390,3 @@ def windnet_netcdf_datapipe( ) return datapipe - - -def check_nans_in_satellite_data(batch: NumpyBatch) -> NumpyBatch: - """ - Check if there are any Nans values in the satellite data. - """ - if np.any(np.isnan(batch[BatchKey.satellite_actual])): - logger.error("Found nans values in satellite data") - - logger.error(batch[BatchKey.satellite_actual].shape) - - # loop over time and channels - for dim in [0, 1]: - for t in range(batch[BatchKey.satellite_actual].shape[dim]): - if dim == 0: - sate_data_one_step = batch[BatchKey.satellite_actual][t] - else: - sate_data_one_step = batch[BatchKey.satellite_actual][:, t] - nans = np.isnan(sate_data_one_step) - - if np.any(nans): - percent_nans = np.sum(nans) / np.prod(sate_data_one_step.shape) * 100 - - logger.error( - f"Found nans values in satellite data at index {t} ({dim=}). 
" - f"{percent_nans}% of values are nans" - ) - else: - logger.error(f"Found no nans values in satellite data at index {t} {dim=}") - - raise ValueError("Found nans values in satellite data") - - return batch - - -if __name__ == "__main__": - start_time = datetime(1900, 1, 1) - end_time = datetime(2050, 1, 1) - configuration_filename = "/home/jacob/Development/ocf_datapipes/tests/config/test.yaml" - dp = windnet_datapipe( - configuration_filename, - start_time=start_time, - end_time=end_time, - ) - datasets = next(iter(dp)) - dataset = combine_to_single_dataset(datasets) - # Need to serialize attributes to strings - dataset.to_netcdf("test.nc", mode="w", engine="h5netcdf", compute=True) - dp = windnet_netcdf_datapipe( - config_filename=configuration_filename, - filenames=["test.nc"], - keys=["gsp", "nwp", "sat", "pv"], - ) - datasets = next(iter(dp)) - print(datasets) diff --git a/tests/training/test_pvnet.py b/tests/training/test_pvnet.py index 035aa493e..1c83b4fb5 100644 --- a/tests/training/test_pvnet.py +++ b/tests/training/test_pvnet.py @@ -3,10 +3,10 @@ from torchdata.datapipes.iter import IterableWrapper from ocf_datapipes.training.pvnet import ( - construct_loctime_pipelines, construct_sliced_data_pipeline, pvnet_datapipe, ) +from ocf_datapipes.training.common import construct_loctime_pipelines from ocf_datapipes.utils.consts import Location import pytest From cf6f392b53ce9f25ada7e5f4ed823bb324f5724a Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Wed, 15 Nov 2023 14:20:34 +0000 Subject: [PATCH 13/29] Include combining in WindNet --- ocf_datapipes/training/windnet.py | 8 +------- tests/training/test_windnet.py | 3 +-- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/ocf_datapipes/training/windnet.py b/ocf_datapipes/training/windnet.py index c0e467139..17fd24095 100644 --- a/ocf_datapipes/training/windnet.py +++ b/ocf_datapipes/training/windnet.py @@ -330,12 +330,6 @@ def windnet_datapipe( # Save out datapipe to NetCDF - # Convert to numpy batch - # datapipe = convert_to_numpy_batch( - # datapipe_dict, - # block_sat, - # block_nwp, - # ) # Merge all the datapipes into one return DictDatasetIterDataPipe( datapipe_dict["gsp"], @@ -343,7 +337,7 @@ def windnet_datapipe( datapipe_dict["sat"], datapipe_dict["pv"], keys=["gsp", "nwp", "sat", "pv"], - ) + ).map(combine_to_single_dataset) def split_dataset_dict_dp(element): diff --git a/tests/training/test_windnet.py b/tests/training/test_windnet.py index f9919d90a..4a89cffe0 100644 --- a/tests/training/test_windnet.py +++ b/tests/training/test_windnet.py @@ -17,9 +17,8 @@ def test_windnet_datapipe(configuration_filename): end_time=end_time, ) datasets = next(iter(dp)) - dataset = combine_to_single_dataset(datasets) # Need to serialize attributes to strings - dataset.to_netcdf("test.nc", mode="w", engine="h5netcdf", compute=True) + datasets.to_netcdf("test.nc", mode="w", engine="h5netcdf", compute=True) dp = windnet_netcdf_datapipe( config_filename=configuration_filename, filenames=["test.nc"], From 399a9be2f6c0c7178f1f673fb24e794b1bb504e1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 15 Nov 2023 14:20:06 +0000 Subject: [PATCH 14/29] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- ocf_datapipes/training/common.py | 8 +++----- ocf_datapipes/training/windnet.py | 32 +++++++++++++++---------------- 2 files changed, 18 insertions(+), 22 deletions(-) diff --git a/ocf_datapipes/training/common.py 
b/ocf_datapipes/training/common.py index 5dd3d0f52..1e752be5f 100644 --- a/ocf_datapipes/training/common.py +++ b/ocf_datapipes/training/common.py @@ -1,7 +1,7 @@ """Common functionality for datapipes""" import logging -from datetime import timedelta, datetime -from typing import List, Union, Optional, Tuple, Dict +from datetime import datetime, timedelta +from typing import Dict, List, Optional, Tuple, Union import numpy as np import xarray as xr @@ -17,10 +17,8 @@ OpenSatellite, OpenTopography, ) -from ocf_datapipes.load.gsp.database import OpenGSPFromDatabaseIterDataPipe -from ocf_datapipes.load.pv.database import OpenPVFromPVSitesDBIterDataPipe from ocf_datapipes.training.pvnet import logger -from ocf_datapipes.utils.consts import NumpyBatch, BatchKey +from ocf_datapipes.utils.consts import BatchKey, NumpyBatch logger = logging.getLogger(__name__) diff --git a/ocf_datapipes/training/windnet.py b/ocf_datapipes/training/windnet.py index 17fd24095..e919a1b71 100644 --- a/ocf_datapipes/training/windnet.py +++ b/ocf_datapipes/training/windnet.py @@ -1,7 +1,7 @@ """Create the training/validation datapipe for training the PVNet Model""" import logging from datetime import datetime, timedelta -from typing import Dict, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import xarray as xr @@ -14,28 +14,26 @@ from ocf_datapipes.load import ( OpenConfiguration, ) -from ocf_datapipes.utils.consts import ( - NEW_NWP_MEAN, - NEW_NWP_STD, - RSS_MEAN, - RSS_STD, - BatchKey, - NumpyBatch, -) -from ocf_datapipes.utils.utils import combine_to_single_dataset, uncombine_from_single_dataset from ocf_datapipes.training.common import ( - normalize_gsp, - normalize_pv, - concat_xr_time_utc, - fill_nans_in_pv, - fill_nans_in_arrays, - AddZeroedSatelliteData, AddZeroedNWPData, + AddZeroedSatelliteData, _get_datapipes_dict, + check_nans_in_satellite_data, + concat_xr_time_utc, construct_loctime_pipelines, + fill_nans_in_arrays, + fill_nans_in_pv, + normalize_gsp, + normalize_pv, slice_datapipes_by_time, - check_nans_in_satellite_data, ) +from ocf_datapipes.utils.consts import ( + NEW_NWP_MEAN, + NEW_NWP_STD, + RSS_MEAN, + RSS_STD, +) +from ocf_datapipes.utils.utils import uncombine_from_single_dataset xr.set_options(keep_attrs=True) logger = logging.getLogger("windnet_datapipe") From ec20ce6d35b1ea6454901317029565fbae406a8f Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Wed, 15 Nov 2023 15:21:16 +0000 Subject: [PATCH 15/29] Fix imports --- ocf_datapipes/training/common.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ocf_datapipes/training/common.py b/ocf_datapipes/training/common.py index 1e752be5f..72135cd14 100644 --- a/ocf_datapipes/training/common.py +++ b/ocf_datapipes/training/common.py @@ -16,8 +16,9 @@ OpenPVFromNetCDF, OpenSatellite, OpenTopography, + OpenPVFromPVSitesDB, + OpenGSPFromDatabase, ) -from ocf_datapipes.training.pvnet import logger from ocf_datapipes.utils.consts import BatchKey, NumpyBatch logger = logging.getLogger(__name__) From 79d39c083ac6072d04f559b47144ea88b56e17ef Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 15 Nov 2023 15:21:50 +0000 Subject: [PATCH 16/29] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- ocf_datapipes/training/common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ocf_datapipes/training/common.py b/ocf_datapipes/training/common.py index 72135cd14..4590914a1 
100644 --- a/ocf_datapipes/training/common.py +++ b/ocf_datapipes/training/common.py @@ -12,12 +12,12 @@ from ocf_datapipes.load import ( OpenConfiguration, OpenGSP, + OpenGSPFromDatabase, OpenNWP, OpenPVFromNetCDF, + OpenPVFromPVSitesDB, OpenSatellite, OpenTopography, - OpenPVFromPVSitesDB, - OpenGSPFromDatabase, ) from ocf_datapipes.utils.consts import BatchKey, NumpyBatch From ecd6791799893f3e358493494a18f61711aa3a78 Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Wed, 15 Nov 2023 15:27:59 +0000 Subject: [PATCH 17/29] Fix import --- ocf_datapipes/training/windnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ocf_datapipes/training/windnet.py b/ocf_datapipes/training/windnet.py index e919a1b71..0f45c8513 100644 --- a/ocf_datapipes/training/windnet.py +++ b/ocf_datapipes/training/windnet.py @@ -33,7 +33,7 @@ RSS_MEAN, RSS_STD, ) -from ocf_datapipes.utils.utils import uncombine_from_single_dataset +from ocf_datapipes.utils.utils import uncombine_from_single_dataset, combine_to_single_dataset xr.set_options(keep_attrs=True) logger = logging.getLogger("windnet_datapipe") From cc4fa02cb919b4b1aa0eb7d0ef7da4bf201f223c Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Wed, 15 Nov 2023 17:22:33 +0000 Subject: [PATCH 18/29] Fix failing combining --- ocf_datapipes/utils/utils.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/ocf_datapipes/utils/utils.py b/ocf_datapipes/utils/utils.py index b1c5bc25a..170bb5103 100644 --- a/ocf_datapipes/utils/utils.py +++ b/ocf_datapipes/utils/utils.py @@ -342,6 +342,7 @@ def combine_to_single_dataset(dataset_dict: dict[str, xr.Dataset]) -> xr.Dataset Combined dataset """ # Convert all data_arrays to datasets + new_dataset_dict = {} for key, datasets in dataset_dict.items(): new_datasets = [] for dataset in datasets: @@ -353,10 +354,11 @@ def combine_to_single_dataset(dataset_dict: dict[str, xr.Dataset]) -> xr.Dataset new_datasets.append(dataset.to_dataset(name=key)) else: new_datasets.append(dataset) - dataset_dict[key] = new_datasets + assert isinstance(new_datasets[-1], xr.Dataset) + new_dataset_dict[key] = new_datasets # Prepend all coordinates and dimensions names with the key in the dataset_dict final_datasets_to_combined = [] - for key, datasets in dataset_dict.items(): + for key, datasets in new_dataset_dict.items(): batched_datasets = [] for dataset in datasets: dataset = dataset.rename( @@ -375,7 +377,10 @@ def combine_to_single_dataset(dataset_dict: dict[str, xr.Dataset]) -> xr.Dataset # Serialize attributes to be JSON-seriaizable final_datasets_to_combined.append(dataset) # Combine all datasets, and append the list of datasets to the dataset_dict + for f_dset in final_datasets_to_combined: + assert isinstance(f_dset, xr.Dataset), f"Dataset is not an xr.Dataset, {type(f_dset)}" combined_dataset = xr.merge(final_datasets_to_combined) + combined_dataset.to_netcdf("combined_dataset.nc", engine="h5netcdf") # Print all attrbutes of the combined dataset return combined_dataset From 2d56248339a1c3195a7757c00e2abc9cbbdee3af Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 15 Nov 2023 15:28:39 +0000 Subject: [PATCH 19/29] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- ocf_datapipes/training/windnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ocf_datapipes/training/windnet.py b/ocf_datapipes/training/windnet.py index 0f45c8513..911b166fa 100644 --- 
a/ocf_datapipes/training/windnet.py +++ b/ocf_datapipes/training/windnet.py @@ -33,7 +33,7 @@ RSS_MEAN, RSS_STD, ) -from ocf_datapipes.utils.utils import uncombine_from_single_dataset, combine_to_single_dataset +from ocf_datapipes.utils.utils import combine_to_single_dataset, uncombine_from_single_dataset xr.set_options(keep_attrs=True) logger = logging.getLogger("windnet_datapipe") From 848822cc682a108c683952abae4db64fc0e861da Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Wed, 15 Nov 2023 17:36:18 +0000 Subject: [PATCH 20/29] Stop saving out NetCDF --- ocf_datapipes/utils/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ocf_datapipes/utils/utils.py b/ocf_datapipes/utils/utils.py index 170bb5103..f50e8124a 100644 --- a/ocf_datapipes/utils/utils.py +++ b/ocf_datapipes/utils/utils.py @@ -380,7 +380,6 @@ def combine_to_single_dataset(dataset_dict: dict[str, xr.Dataset]) -> xr.Dataset for f_dset in final_datasets_to_combined: assert isinstance(f_dset, xr.Dataset), f"Dataset is not an xr.Dataset, {type(f_dset)}" combined_dataset = xr.merge(final_datasets_to_combined) - combined_dataset.to_netcdf("combined_dataset.nc", engine="h5netcdf") # Print all attrbutes of the combined dataset return combined_dataset From e887c58c96c943d85a9af6109d8bca72f2e6fa7e Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Thu, 16 Nov 2023 11:01:09 +0000 Subject: [PATCH 21/29] Fix tests --- ocf_datapipes/training/windnet.py | 56 +++++++++++++++++++++++++------ ocf_datapipes/utils/utils.py | 2 +- tests/utils/test_utils.py | 15 +++------ 3 files changed, 52 insertions(+), 21 deletions(-) diff --git a/ocf_datapipes/training/windnet.py b/ocf_datapipes/training/windnet.py index 911b166fa..0e3b45b79 100644 --- a/ocf_datapipes/training/windnet.py +++ b/ocf_datapipes/training/windnet.py @@ -46,7 +46,7 @@ def scale_wind_speed_to_power(x: Union[xr.DataArray, xr.Dataset]): Roughly, double speed in m/s, and convert with the power scale Args: - x: + x: xr. 
Returns: @@ -60,7 +60,7 @@ def scale_wind_speed_to_power(x: Union[xr.DataArray, xr.Dataset]): @functional_datapipe("dict_datasets") class DictDatasetIterDataPipe(IterDataPipe): - """ """ + """Create a dictionary of xr.Datasets from a set of iterators""" datapipes: Tuple[IterDataPipe] length: Optional[int] @@ -87,22 +87,25 @@ def __iter__(self): @functional_datapipe("load_dict_datasets") class LoadDictDatasetIterDataPipe(IterDataPipe): - """ """ + """Load NetCDF files and split them back into individual xr.Datasets""" filenames: List[str] keys: List[str] - configuration: Configuration - def __init__(self, filenames: List[str], keys: List[str], configuration: Configuration): - """Init""" + def __init__(self, filenames: List[str], keys: List[str]): + """ + Load NetCDF files and split them back into individual xr.Datasets + + Args: + filenames: List of filesnames to load + keys: List of keys from each file to use, each key should be a dataarray in the xr.Dataset + """ super().__init__() self.keys = keys self.filenames = filenames - self.configuration = configuration def __iter__(self): - """Iter""" - # Iterate through each filename, loading it, uncombining it, and then yielding it + """Iterate through each filename, loading it, uncombining it, and then yielding it""" while True: for filename in self.filenames: dataset = xr.open_dataset(filename) @@ -375,10 +378,43 @@ def windnet_netcdf_datapipe( datapipe_dict_dp: IterDataPipe = LoadDictDatasetIterDataPipe( filenames=filenames, keys=keys, - configuration=configuration, ).map(split_dataset_dict_dp) datapipe = datapipe_dict_dp.convert_to_numpy_batch( block_nwp=block_nwp, block_sat=block_sat, configuration=configuration ) return datapipe + + +if __name__ == "__main__": + configuration_filename = "/home/jacob/Development/ocf_datapipes/tests/config/test.yaml" + start_time = datetime(1900, 1, 1) + end_time = datetime(2050, 1, 1) + dp = windnet_datapipe( + configuration_filename, + start_time=start_time, + end_time=end_time, + ) + datasets = next(iter(dp)) + print("----------------------------------------------------------") + print(datasets) + print("----------------------------------------------------------") + exit() + assert isinstance(dataset, xr.Dataset) + multiple_datasets = uncombine_from_single_dataset(dataset) + for key in multiple_datasets.keys(): + if "time_utc" in multiple_datasets[key].coords.keys(): + time_coord = "time_utc" + else: + time_coord = "target_time_utc" + for i in range(len(multiple_datasets[key][time_coord])): + # Assert that coordinates are the same + assert ( + datasets[key][i].coords.keys() + == multiple_datasets[key].isel({time_coord: i}).coords.keys() + ) + # Assert that data for each of the coords is the same + for coord_key in datasets[key][i].coords.keys(): + assert datasets[key][i][coord_key].equals( + multiple_datasets[key].isel({time_coord: i})[coord_key] + ) diff --git a/ocf_datapipes/utils/utils.py b/ocf_datapipes/utils/utils.py index f50e8124a..66522b334 100644 --- a/ocf_datapipes/utils/utils.py +++ b/ocf_datapipes/utils/utils.py @@ -402,7 +402,7 @@ def uncombine_from_single_dataset(combined_dataset: xr.Dataset) -> dict[str, xr. 
dataset_dims = list(dataset.coords) for dim in dataset_dims: if f"{key}__" not in dim: - dataset = dataset.drop(dim) + dataset: xr.DataArray = dataset.drop(dim) dataset = dataset.rename( {dim: dim.split(f"{key}__")[1] for dim in dataset.dims if dim not in dataset.coords} ) diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py index 4c5c8819a..df537f047 100644 --- a/tests/utils/test_utils.py +++ b/tests/utils/test_utils.py @@ -21,8 +21,7 @@ def test_combine_uncombine_from_single_dataset(configuration_filename): start_time=start_time, end_time=end_time, ) - datasets = next(iter(dp)) - dataset: xr.Dataset = combine_to_single_dataset(datasets) + dataset: xr.Dataset = next(iter(dp)) assert isinstance(dataset, xr.Dataset) multiple_datasets = uncombine_from_single_dataset(dataset) for key in multiple_datasets.keys(): @@ -31,13 +30,9 @@ def test_combine_uncombine_from_single_dataset(configuration_filename): else: time_coord = "target_time_utc" for i in range(len(multiple_datasets[key][time_coord])): - # Assert that coordinates are the same - assert ( - datasets[key][i].coords.keys() - == multiple_datasets[key].isel({time_coord: i}).coords.keys() - ) # Assert that data for each of the coords is the same - for coord_key in datasets[key][i].coords.keys(): - assert datasets[key][i][coord_key].equals( - multiple_datasets[key].isel({time_coord: i})[coord_key] + for coord_key in multiple_datasets[key][i].coords.keys(): + np.testing.assert_equal( + multiple_datasets[key].isel({time_coord: i})[coord_key].values, + dataset[key][i][f"{key}__{coord_key}"].values, ) From f91073958600c41c98a759b5713b7bb3a033839c Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Thu, 16 Nov 2023 11:11:06 +0000 Subject: [PATCH 22/29] Linting fixes --- ocf_datapipes/training/windnet.py | 39 ++----------------------------- 1 file changed, 2 insertions(+), 37 deletions(-) diff --git a/ocf_datapipes/training/windnet.py b/ocf_datapipes/training/windnet.py index 0e3b45b79..be66ba312 100644 --- a/ocf_datapipes/training/windnet.py +++ b/ocf_datapipes/training/windnet.py @@ -5,7 +5,6 @@ import xarray as xr -xr.set_options(keep_attrs=True) from torchdata.datapipes import functional_datapipe from torchdata.datapipes.iter import IterableWrapper, IterDataPipe @@ -46,10 +45,10 @@ def scale_wind_speed_to_power(x: Union[xr.DataArray, xr.Dataset]): Roughly, double speed in m/s, and convert with the power scale Args: - x: xr. 
+ x: xr.DataArray or xr.Dataset containing wind speed Returns: - + Rescaled wind speed to MWh roughly """ # Convert knots to m/s x = x * 0.514444 @@ -384,37 +383,3 @@ def windnet_netcdf_datapipe( ) return datapipe - - -if __name__ == "__main__": - configuration_filename = "/home/jacob/Development/ocf_datapipes/tests/config/test.yaml" - start_time = datetime(1900, 1, 1) - end_time = datetime(2050, 1, 1) - dp = windnet_datapipe( - configuration_filename, - start_time=start_time, - end_time=end_time, - ) - datasets = next(iter(dp)) - print("----------------------------------------------------------") - print(datasets) - print("----------------------------------------------------------") - exit() - assert isinstance(dataset, xr.Dataset) - multiple_datasets = uncombine_from_single_dataset(dataset) - for key in multiple_datasets.keys(): - if "time_utc" in multiple_datasets[key].coords.keys(): - time_coord = "time_utc" - else: - time_coord = "target_time_utc" - for i in range(len(multiple_datasets[key][time_coord])): - # Assert that coordinates are the same - assert ( - datasets[key][i].coords.keys() - == multiple_datasets[key].isel({time_coord: i}).coords.keys() - ) - # Assert that data for each of the coords is the same - for coord_key in datasets[key][i].coords.keys(): - assert datasets[key][i][coord_key].equals( - multiple_datasets[key].isel({time_coord: i})[coord_key] - ) From 75bb0f76716d66e1c5ae88ea9397cf2caf26761a Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Thu, 16 Nov 2023 11:14:50 +0000 Subject: [PATCH 23/29] Fix lint --- ocf_datapipes/training/windnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ocf_datapipes/training/windnet.py b/ocf_datapipes/training/windnet.py index be66ba312..8ffb66ba1 100644 --- a/ocf_datapipes/training/windnet.py +++ b/ocf_datapipes/training/windnet.py @@ -118,7 +118,7 @@ def __iter__(self): @functional_datapipe("convert_to_numpy_batch") class ConvertToNumpyBatchIterDataPipe(IterDataPipe): - """ """ + """Converts Xarray Dataset to Numpy Batch""" def __init__( self, From 6b43db6b54f717f9aa30cc5df6c9944c8682c6dc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 16 Nov 2023 11:13:24 +0000 Subject: [PATCH 24/29] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- ocf_datapipes/training/windnet.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ocf_datapipes/training/windnet.py b/ocf_datapipes/training/windnet.py index 8ffb66ba1..d0eac65a2 100644 --- a/ocf_datapipes/training/windnet.py +++ b/ocf_datapipes/training/windnet.py @@ -4,7 +4,6 @@ from typing import List, Optional, Tuple, Union import xarray as xr - from torchdata.datapipes import functional_datapipe from torchdata.datapipes.iter import IterableWrapper, IterDataPipe From f0d43f359c8a60606cda7891d9a14127d11b36f0 Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Thu, 16 Nov 2023 11:16:29 +0000 Subject: [PATCH 25/29] Lint fixes --- ocf_datapipes/training/windnet.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ocf_datapipes/training/windnet.py b/ocf_datapipes/training/windnet.py index d0eac65a2..89e0c07b8 100644 --- a/ocf_datapipes/training/windnet.py +++ b/ocf_datapipes/training/windnet.py @@ -96,7 +96,8 @@ def __init__(self, filenames: List[str], keys: List[str]): Args: filenames: List of filesnames to load - keys: List of keys from each file to use, each key should be a dataarray in the xr.Dataset + keys: List of keys from each file to 
use, each key should be a + dataarray in the xr.Dataset """ super().__init__() self.keys = keys From 6b25ae9e2ad0c434c675837045d4cbebc3391d83 Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Thu, 16 Nov 2023 16:40:02 +0000 Subject: [PATCH 26/29] Fix imports --- ocf_datapipes/training/common.py | 485 +++++++++++---------------- ocf_datapipes/training/pvnet.py | 544 +------------------------------ 2 files changed, 207 insertions(+), 822 deletions(-) diff --git a/ocf_datapipes/training/common.py b/ocf_datapipes/training/common.py index 4590914a1..51b088093 100644 --- a/ocf_datapipes/training/common.py +++ b/ocf_datapipes/training/common.py @@ -152,295 +152,6 @@ def open_and_return_datapipes( return used_datapipes -def get_and_return_overlapping_time_periods_and_t0(used_datapipes: dict, key_for_t0: str = "gsp"): - """ - Takes datapipes and obtains the overlapping time periods + t0 time datapipes - - Args: - used_datapipes: Dictionary of datapipes to compute the time intersection of - key_for_t0: Key to use for the t0 datapipe - - Returns: - Dictionary of datapipes with the proper time slices selected - """ - datapipes_for_time_periods = [] # Using later to compute intersections - datapipes_to_return = {} # Returned along with original ones - t0_datapipe = None - configuration = used_datapipes.pop("config") - for key, datapipe in used_datapipes.items(): - if "topo" in key: - continue - if key_for_t0 in key: - forked_datapipes = datapipe.fork(3, buffer_size=100) - t0_datapipe = forked_datapipes[2] - else: - forked_datapipes = datapipe.fork(2, buffer_size=100) - datapipes_to_return[key] = forked_datapipes[0] - if "nwp" == key: - time_periods_datapipe = forked_datapipes[1].get_contiguous_time_periods( - sample_period_duration=timedelta(hours=3), # Init times are 3 hours apart - history_duration=timedelta(minutes=configuration.input_data.nwp.history_minutes), - forecast_duration=timedelta(minutes=configuration.input_data.nwp.forecast_minutes), - time_dim="init_time_utc", - ) - datapipes_for_time_periods.append(time_periods_datapipe) - - if "sat" == key: - time_periods_datapipe = forked_datapipes[1].get_contiguous_time_periods( - sample_period_duration=timedelta(minutes=5), - history_duration=timedelta( - minutes=configuration.input_data.satellite.history_minutes - ), - forecast_duration=timedelta(minutes=0), - ) - datapipes_for_time_periods.append(time_periods_datapipe) - - if "hrv" == key: - time_periods_datapipe = forked_datapipes[1].get_contiguous_time_periods( - sample_period_duration=timedelta(minutes=5), - history_duration=timedelta( - minutes=configuration.input_data.hrvsatellite.history_minutes - ), - forecast_duration=timedelta(minutes=0), - ) - datapipes_for_time_periods.append(time_periods_datapipe) - - if "pv" == key: - time_periods_datapipe = forked_datapipes[1].get_contiguous_time_periods( - sample_period_duration=timedelta(minutes=5), - history_duration=timedelta(minutes=configuration.input_data.pv.history_minutes), - forecast_duration=timedelta(minutes=configuration.input_data.pv.forecast_minutes), - ) - datapipes_for_time_periods.append(time_periods_datapipe) - if "gsp" == key: - time_periods_datapipe = forked_datapipes[1].get_contiguous_time_periods( - sample_period_duration=timedelta(minutes=30), - history_duration=timedelta(minutes=configuration.input_data.gsp.history_minutes), - forecast_duration=timedelta(minutes=configuration.input_data.gsp.forecast_minutes), - ) - datapipes_for_time_periods.append(time_periods_datapipe) - - # Now have the forked ones - # find joint 
overlapping timer periods - logger.debug("Getting joint time periods") - overlapping_datapipe = datapipes_for_time_periods[0].select_overlapping_time_slice( - secondary_datapipes=datapipes_for_time_periods[1:], - ) - - # select time periods - t0_datapipe = t0_datapipe.select_time_periods(time_periods=overlapping_datapipe) - - num_t0_datapipes = len(datapipes_to_return.keys()) # One for each input - t0_datapipes = t0_datapipe.select_t0_time(return_all_times=False).fork( - num_t0_datapipes, buffer_size=100 - ) - - for i, key in enumerate(list(datapipes_to_return.keys())): - datapipes_to_return[key + "_t0"] = t0_datapipes[i] - - # Re-add config for later - datapipes_to_return["config"] = configuration - if "topo" in used_datapipes.keys(): - datapipes_to_return["topo"] = used_datapipes["topo"] - return datapipes_to_return - - -def add_selected_time_slices_from_datapipes(used_datapipes: dict): - """ - Takes datapipes and t0 datapipes and returns the sliced datapipes - - Args: - used_datapipes: Dictionary of used datapipes and t0 ones - - Returns: - Dictionary of datapipes after the time slices are selected - """ - datapipes_to_return = {} # Returned along with original ones - configuration = used_datapipes.pop("config") - for key, datapipe in used_datapipes.items(): - if "topo" in key: - continue - if "_t0" in key: - continue - if "nwp" == key: - datapipes_to_return[key] = datapipe.convert_to_nwp_target_time( - t0_datapipe=used_datapipes[key + "_t0"], - sample_period_duration=timedelta(hours=1), - history_duration=timedelta(minutes=configuration.input_data.nwp.history_minutes), - forecast_duration=timedelta(minutes=configuration.input_data.nwp.forecast_minutes), - ) - - if "sat" == key: - datapipes_to_return[key] = datapipe.select_time_slice( - t0_datapipe=used_datapipes[key + "_t0"], - history_duration=timedelta( - minutes=configuration.input_data.satellite.history_minutes - ), - forecast_duration=timedelta(minutes=0), - sample_period_duration=timedelta(minutes=5), - ) - - if "hrv" == key: - datapipes_to_return[key] = datapipe.select_time_slice( - t0_datapipe=used_datapipes[key + "_t0"], - history_duration=timedelta( - minutes=configuration.input_data.hrvsatellite.history_minutes - ), - forecast_duration=timedelta(minutes=0), - sample_period_duration=timedelta(minutes=5), - ) - - if "pv" == key: - pv_1, pv_2 = used_datapipes[key + "_t0"].fork(2) - pv_dp1, pv_dp2 = datapipe.fork(2) - datapipes_to_return[key] = pv_dp1.select_time_slice( - t0_datapipe=pv_1, - history_duration=timedelta(minutes=configuration.input_data.pv.history_minutes), - forecast_duration=timedelta(minutes=0), - sample_period_duration=timedelta(minutes=5), - ) - datapipes_to_return[key + "_future"] = pv_dp2.select_time_slice( - t0_datapipe=pv_2, - history_duration=timedelta(minutes=0), - forecast_duration=timedelta(minutes=configuration.input_data.pv.forecast_minutes), - sample_period_duration=timedelta(minutes=5), - ) - - if "gsp" == key: - gsp_1, gsp_2 = used_datapipes[key + "_t0"].fork(2) - gsp_dp1, gsp_dp2 = datapipe.fork(2) - datapipes_to_return[key] = gsp_dp1.select_time_slice( - t0_datapipe=gsp_1, - history_duration=timedelta(minutes=configuration.input_data.gsp.history_minutes), - forecast_duration=timedelta(minutes=0), - sample_period_duration=timedelta(minutes=30), - ) - datapipes_to_return[key + "_future"] = gsp_dp2.select_time_slice( - t0_datapipe=gsp_2, - history_duration=timedelta(minutes=0), - forecast_duration=timedelta(minutes=configuration.input_data.gsp.forecast_minutes), - 
sample_period_duration=timedelta(minutes=30), - ) - if "topo" in used_datapipes.keys(): - datapipes_to_return["topo"] = used_datapipes["topo"] - datapipes_to_return["config"] = configuration - return datapipes_to_return - - -def create_t0_and_loc_datapipes( - datapipes_dict: dict, - configuration: Configuration, - key_for_t0: str = "gsp", - shuffle: bool = True, - nwp_max_dropout_minutes: int = 0, - nwp_max_staleness_minutes: int = 180, -): - """ - Takes source datapipes and returns datapipes of appropriate sample pairs of locations and times. - - The (location, t0) pairs are sampled without replacement. - - Args: - datapipes_dict: Dictionary of datapipes of input sources for which we want to select - appropriate location and times. - configuration: Configuration object for inputs. - key_for_t0: Key to use for the t0 datapipe. Must be "gsp" or "pv". - shuffle: Whether to use the internal shuffle function when yielding location times. Else - location times will be heavily ordered. - nwp_max_dropout_minutes: If using dropout on NWP, sometimes we have to go back to previous - NWP init time. In order to accomodate for this possibility in selecting times, set - `nwp_max_dropout_minutes` as the max NWP dropout delay you plan to use. - nwp_max_staleness_minutes: Sets a limit on how stale an NWP init time is allowed to be - whilst still being used to costruct an example - - Returns: - location datapipe, t0 datapipe - - """ - assert key_for_t0 in datapipes_dict - assert key_for_t0 in ["gsp", "pv"] - assert nwp_max_staleness_minutes >= nwp_max_dropout_minutes - - contiguous_time_datapipes = [] # Used to store contiguous time periods from each data source - - datapipes_dict[key_for_t0], key_datapipe = datapipes_dict[key_for_t0].fork(2, buffer_size=5) - - for key in datapipes_dict.keys(): - if key in ["topo"]: - continue - - elif key == "nwp": - datapipes_dict["nwp"], datapipe_copy = datapipes_dict["nwp"].fork(2, buffer_size=5) - - # NWP is a forecast product so gets its own contiguous function - time_periods = datapipe_copy.get_contiguous_time_periods_nwp( - history_duration=timedelta(minutes=configuration.input_data.nwp.history_minutes), - max_staleness=timedelta(minutes=nwp_max_staleness_minutes), - max_dropout=timedelta(minutes=nwp_max_dropout_minutes), - time_dim="init_time_utc", - ) - - contiguous_time_datapipes.append(time_periods) - - else: - if key == "sat": - sample_frequency = 5 - history_duration = configuration.input_data.satellite.history_minutes - forecast_duration = 0 - time_dim = "time_utc" - - elif key == "hrv": - sample_frequency = 5 - history_duration = configuration.input_data.hrvsatellite.history_minutes - forecast_duration = 0 - time_dim = "time_utc" - - elif key == "pv": - sample_frequency = 5 - history_duration = configuration.input_data.pv.history_minutes - forecast_duration = configuration.input_data.pv.forecast_minutes - time_dim = "time_utc" - - elif key == "gsp": - sample_frequency = 30 - history_duration = configuration.input_data.gsp.history_minutes - forecast_duration = configuration.input_data.gsp.forecast_minutes - time_dim = "time_utc" - - else: - raise ValueError(f"Unexpected key: {key}") - - datapipes_dict[key], datapipe_copy = datapipes_dict[key].fork(2, buffer_size=5) - - time_periods = datapipe_copy.get_contiguous_time_periods( - sample_period_duration=timedelta(minutes=sample_frequency), - history_duration=timedelta(minutes=history_duration), - forecast_duration=timedelta(minutes=forecast_duration), - time_dim=time_dim, - ) - - 
contiguous_time_datapipes.append(time_periods) - - # Find joint overlapping contiguous time periods - if len(contiguous_time_datapipes) > 1: - logger.debug("Getting joint time periods") - overlapping_datapipe = contiguous_time_datapipes[0].select_overlapping_time_slice( - secondary_datapipes=contiguous_time_datapipes[1:], - ) - else: - logger.debug("Skipping getting joint time periods") - overlapping_datapipe = contiguous_time_datapipes[0] - - # Select time periods and set length - key_datapipe = key_datapipe.select_time_periods(time_periods=overlapping_datapipe) - - t0_loc_datapipe = key_datapipe.select_loc_and_t0(return_all=True, shuffle=shuffle) - - location_pipe, t0_datapipe = t0_loc_datapipe.unzip(sequence_length=2) - - return location_pipe, t0_datapipe - - def normalize_gsp(x): """Normalize the GSP data @@ -1001,3 +712,199 @@ def check_nans_in_satellite_data(batch: NumpyBatch) -> NumpyBatch: raise ValueError("Found nans values in satellite data") return batch + + +def add_selected_time_slices_from_datapipes(used_datapipes: dict): + """ + Takes datapipes and t0 datapipes and returns the sliced datapipes + + Args: + used_datapipes: Dictionary of used datapipes and t0 ones + + Returns: + Dictionary of datapipes after the time slices are selected + """ + datapipes_to_return = {} # Returned along with original ones + configuration = used_datapipes.pop("config") + for key, datapipe in used_datapipes.items(): + if "topo" in key: + continue + if "_t0" in key: + continue + if "nwp" == key: + datapipes_to_return[key] = datapipe.convert_to_nwp_target_time( + t0_datapipe=used_datapipes[key + "_t0"], + sample_period_duration=timedelta(hours=1), + history_duration=timedelta(minutes=configuration.input_data.nwp.history_minutes), + forecast_duration=timedelta(minutes=configuration.input_data.nwp.forecast_minutes), + ) + + if "sat" == key: + datapipes_to_return[key] = datapipe.select_time_slice( + t0_datapipe=used_datapipes[key + "_t0"], + history_duration=timedelta( + minutes=configuration.input_data.satellite.history_minutes + ), + forecast_duration=timedelta(minutes=0), + sample_period_duration=timedelta(minutes=5), + ) + + if "hrv" == key: + datapipes_to_return[key] = datapipe.select_time_slice( + t0_datapipe=used_datapipes[key + "_t0"], + history_duration=timedelta( + minutes=configuration.input_data.hrvsatellite.history_minutes + ), + forecast_duration=timedelta(minutes=0), + sample_period_duration=timedelta(minutes=5), + ) + + if "pv" == key: + pv_1, pv_2 = used_datapipes[key + "_t0"].fork(2) + pv_dp1, pv_dp2 = datapipe.fork(2) + datapipes_to_return[key] = pv_dp1.select_time_slice( + t0_datapipe=pv_1, + history_duration=timedelta(minutes=configuration.input_data.pv.history_minutes), + forecast_duration=timedelta(minutes=0), + sample_period_duration=timedelta(minutes=5), + ) + datapipes_to_return[key + "_future"] = pv_dp2.select_time_slice( + t0_datapipe=pv_2, + history_duration=timedelta(minutes=0), + forecast_duration=timedelta(minutes=configuration.input_data.pv.forecast_minutes), + sample_period_duration=timedelta(minutes=5), + ) + + if "gsp" == key: + gsp_1, gsp_2 = used_datapipes[key + "_t0"].fork(2) + gsp_dp1, gsp_dp2 = datapipe.fork(2) + datapipes_to_return[key] = gsp_dp1.select_time_slice( + t0_datapipe=gsp_1, + history_duration=timedelta(minutes=configuration.input_data.gsp.history_minutes), + forecast_duration=timedelta(minutes=0), + sample_period_duration=timedelta(minutes=30), + ) + datapipes_to_return[key + "_future"] = gsp_dp2.select_time_slice( + t0_datapipe=gsp_2, + 
history_duration=timedelta(minutes=0), + forecast_duration=timedelta(minutes=configuration.input_data.gsp.forecast_minutes), + sample_period_duration=timedelta(minutes=30), + ) + if "topo" in used_datapipes.keys(): + datapipes_to_return["topo"] = used_datapipes["topo"] + datapipes_to_return["config"] = configuration + return datapipes_to_return + + +def create_t0_and_loc_datapipes( + datapipes_dict: dict, + configuration: Configuration, + key_for_t0: str = "gsp", + shuffle: bool = True, + nwp_max_dropout_minutes: int = 0, + nwp_max_staleness_minutes: int = 180, +): + """ + Takes source datapipes and returns datapipes of appropriate sample pairs of locations and times. + + The (location, t0) pairs are sampled without replacement. + + Args: + datapipes_dict: Dictionary of datapipes of input sources for which we want to select + appropriate location and times. + configuration: Configuration object for inputs. + key_for_t0: Key to use for the t0 datapipe. Must be "gsp" or "pv". + shuffle: Whether to use the internal shuffle function when yielding location times. Else + location times will be heavily ordered. + nwp_max_dropout_minutes: If using dropout on NWP, sometimes we have to go back to previous + NWP init time. In order to accomodate for this possibility in selecting times, set + `nwp_max_dropout_minutes` as the max NWP dropout delay you plan to use. + nwp_max_staleness_minutes: Sets a limit on how stale an NWP init time is allowed to be + whilst still being used to costruct an example + + Returns: + location datapipe, t0 datapipe + + """ + assert key_for_t0 in datapipes_dict + assert key_for_t0 in ["gsp", "pv"] + assert nwp_max_staleness_minutes >= nwp_max_dropout_minutes + + contiguous_time_datapipes = [] # Used to store contiguous time periods from each data source + + datapipes_dict[key_for_t0], key_datapipe = datapipes_dict[key_for_t0].fork(2, buffer_size=5) + + for key in datapipes_dict.keys(): + if key in ["topo"]: + continue + + elif key == "nwp": + datapipes_dict["nwp"], datapipe_copy = datapipes_dict["nwp"].fork(2, buffer_size=5) + + # NWP is a forecast product so gets its own contiguous function + time_periods = datapipe_copy.get_contiguous_time_periods_nwp( + history_duration=timedelta(minutes=configuration.input_data.nwp.history_minutes), + max_staleness=timedelta(minutes=nwp_max_staleness_minutes), + max_dropout=timedelta(minutes=nwp_max_dropout_minutes), + time_dim="init_time_utc", + ) + + contiguous_time_datapipes.append(time_periods) + + else: + if key == "sat": + sample_frequency = 5 + history_duration = configuration.input_data.satellite.history_minutes + forecast_duration = 0 + time_dim = "time_utc" + + elif key == "hrv": + sample_frequency = 5 + history_duration = configuration.input_data.hrvsatellite.history_minutes + forecast_duration = 0 + time_dim = "time_utc" + + elif key == "pv": + sample_frequency = 5 + history_duration = configuration.input_data.pv.history_minutes + forecast_duration = configuration.input_data.pv.forecast_minutes + time_dim = "time_utc" + + elif key == "gsp": + sample_frequency = 30 + history_duration = configuration.input_data.gsp.history_minutes + forecast_duration = configuration.input_data.gsp.forecast_minutes + time_dim = "time_utc" + + else: + raise ValueError(f"Unexpected key: {key}") + + datapipes_dict[key], datapipe_copy = datapipes_dict[key].fork(2, buffer_size=5) + + time_periods = datapipe_copy.get_contiguous_time_periods( + sample_period_duration=timedelta(minutes=sample_frequency), + 
history_duration=timedelta(minutes=history_duration), + forecast_duration=timedelta(minutes=forecast_duration), + time_dim=time_dim, + ) + + contiguous_time_datapipes.append(time_periods) + + # Find joint overlapping contiguous time periods + if len(contiguous_time_datapipes) > 1: + logger.debug("Getting joint time periods") + overlapping_datapipe = contiguous_time_datapipes[0].select_overlapping_time_slice( + secondary_datapipes=contiguous_time_datapipes[1:], + ) + else: + logger.debug("Skipping getting joint time periods") + overlapping_datapipe = contiguous_time_datapipes[0] + + # Select time periods and set length + key_datapipe = key_datapipe.select_time_periods(time_periods=overlapping_datapipe) + + t0_loc_datapipe = key_datapipe.select_loc_and_t0(return_all=True, shuffle=shuffle) + + location_pipe, t0_datapipe = t0_loc_datapipe.unzip(sequence_length=2) + + return location_pipe, t0_datapipe diff --git a/ocf_datapipes/training/pvnet.py b/ocf_datapipes/training/pvnet.py index 537c1e039..62fa9e4a9 100644 --- a/ocf_datapipes/training/pvnet.py +++ b/ocf_datapipes/training/pvnet.py @@ -12,8 +12,17 @@ from ocf_datapipes.config.model import Configuration from ocf_datapipes.load import OpenGSPFromDatabase, OpenPVFromPVSitesDB from ocf_datapipes.training.common import ( - create_t0_and_loc_datapipes, - open_and_return_datapipes, + AddZeroedNWPData, + AddZeroedSatelliteData, + _get_datapipes_dict, + check_nans_in_satellite_data, + concat_xr_time_utc, + construct_loctime_pipelines, + fill_nans_in_arrays, + fill_nans_in_pv, + normalize_gsp, + normalize_pv, + slice_datapipes_by_time, ) from ocf_datapipes.utils.consts import ( NEW_NWP_MEAN, @@ -28,537 +37,6 @@ logger = logging.getLogger("pvnet_datapipe") -def normalize_gsp(x): - """Normalize the GSP data - - Args: - x: Input DataArray - - Returns: - Normalized DataArray - """ - return x / x.effective_capacity_mwp - - -def normalize_pv(x): - """Normalize the PV data - - Args: - x: Input DataArray - - Returns: - Normalized DataArray - """ - return (x / x.nominal_capacity_wp).clip(None, 5) - - -def production_sat_scale(x): - """Scale the production satellite data - - Args: - x: Input DataArray - - Returns: - Scaled DataArray - """ - return x / 1024 - - -def concat_xr_time_utc(gsp_dataarrays: List[xr.DataArray]): - """This function is used to combine the split history and future gsp/pv dataarrays. - - These are split inside the `slice_datapipes_by_time()` function below. - - Splitting them inside that function allows us to apply dropout to the - history GSP/PV whilst leaving the future GSP/PV without NaNs. - - We recombine the history and future with this function to allow us to use the - `MergeNumpyModalities()` datapipe without redefining the BatchKeys. - - The `pvnet` model was also written to use a GSP/PV array which has historical and future - and to split it out. These maintains that assumption. - """ - return xr.concat(gsp_dataarrays, dim="time_utc") - - -def gsp_drop_national(x: Union[xr.DataArray, xr.Dataset]): - """Drop entries for national PV output - - Args: - x: Data source of gsp data - - Returns: - Filtered data source - """ - return x.where(x.gsp_id != 0, drop=True) - - -@functional_datapipe("pvnet_select_pv_by_ml_id") -class PVNetSelectPVbyMLIDIterDataPipe(IterDataPipe): - """Select specific set of PV systems by ML ID.""" - - def __init__(self, source_datapipe: IterDataPipe, ml_ids: np.array): - """Select specific set of PV systems by ML ID. 
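
As a side note on the `create_t0_and_loc_datapipes` function re-added to common.py above: the per-source calls to `get_contiguous_time_periods` do the core work, but that datapipe itself is not part of this diff. A rough standalone sketch of the idea in plain pandas might look like the following (the `contiguous_periods` helper and its `start_dt`/`end_dt` columns are illustrative only, not the library API):

    import pandas as pd

    def contiguous_periods(times, sample_period, history, forecast):
        # Split the time index into runs with no gap larger than the sample
        # period, then keep the sub-range of each run that leaves room for
        # `history` of lookback and `forecast` of lookahead around a t0.
        times = pd.DatetimeIndex(times).sort_values().to_series()
        run_ids = (times.diff() > sample_period).cumsum()
        periods = []
        for _, run in times.groupby(run_ids):
            first_t0 = run.iloc[0] + history
            last_t0 = run.iloc[-1] - forecast
            if first_t0 <= last_t0:
                periods.append((first_t0, last_t0))
        return pd.DataFrame(periods, columns=["start_dt", "end_dt"])

The NWP branch above uses the separate `get_contiguous_time_periods_nwp` variant instead, which also budgets for init-time staleness and dropout via `max_staleness` and `max_dropout`.
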
- - Args: - source_datapipe: Datapipe emitting PV xarray data - ml_ids: List-like of ML IDs to select - - Returns: - Filtered data source - """ - self.source_datapipe = source_datapipe - self.ml_ids = ml_ids - - def __iter__(self): - for x in self.source_datapipe: - # Check for missing IDs - ml_ids_not_in_data = ~np.isin(self.ml_ids, x.ml_id) - if ml_ids_not_in_data.any(): - missing_ml_ids = np.array(self.ml_ids)[ml_ids_not_in_data] - logger.warning( - f"The following ML IDs were mising in the PV site-level input data: " - f"{missing_ml_ids}. The values for these IDs will be set to NaN." - ) - - x_filtered = ( - # Many ML-IDs are null, so filter first - x.where(~x.ml_id.isnull(), drop=True) - # Swap dimensions so we can select by ml_id coordinate - .swap_dims({"pv_system_id": "ml_id"}) - # Select IDs - missing IDs are given NaN values - .reindex(ml_id=self.ml_ids) - # Swap back dimensions - .swap_dims({"ml_id": "pv_system_id"}) - ) - yield x_filtered - - -def fill_nans_in_pv(x: Union[xr.DataArray, xr.Dataset]): - """Fill NaNs in PV data with the value -1 - - Args: - x: Input DataArray - - Returns: - Normalized DataArray - """ - return x.fillna(-1) - - -def fill_nans_in_arrays(batch: NumpyBatch) -> NumpyBatch: - """Fills all NaN values in each np.ndarray in the batch dictionary with zeros. - - Operation is performed in-place on the batch. - """ - logger.info("Filling Nans with zeros") - for k, v in batch.items(): - if isinstance(v, np.ndarray): - np.nan_to_num(v, copy=False, nan=0.0) - return batch - - -class AddZeroedSatelliteData: - """A callable class used to add zeroed-out satellite data to batches of data. - - This is useful - to speed up batch loading if pre-training the output part of the network without satellite - inputs. - """ - - def __init__(self, configuration: Configuration, is_hrv: bool = False): - """A callable class used to add zeroed-out satellite data to batches of data. - - Args: - configuration: Configuration object - is_hrv: If False, non-HRV data is added by called function, else HRV. - """ - - self.configuration = configuration - self.is_hrv = is_hrv - - def __call__(self, batch: NumpyBatch) -> NumpyBatch: - """Add zeroed-out satellite data to batch with shape accoriding to supplied configuration. - - Batch is modified in-place and returned. - - Args: - batch: Numpy batch of input data. - """ - - variable = "hrvsatellite" if self.is_hrv else "satellite" - - satellite_config = getattr(self.configuration.input_data, variable) - - n_channels = len(getattr(satellite_config, f"{variable}_channels")) - height = getattr(satellite_config, f"{variable}_image_size_pixels_height") - width = getattr(satellite_config, f"{variable}_image_size_pixels_width") - - sequence_len = satellite_config.history_minutes // 5 + 1 - 3 - - batch[getattr(BatchKey, f"{variable}_actual")] = np.zeros( - (sequence_len, n_channels, height, width) - ) - - return batch - - -class AddZeroedNWPData: - """A callable class used to add zeroed-out NWP data to batches of data. - - This is useful to speed up batch loading if pre-training the output part of the network without - NWP inputs. - """ - - def __init__(self, configuration: Configuration): - """A callable class used to add zeroed-out NWP data to batches of data. - - Args: - configuration: Configuration object - """ - self.configuration = configuration - - def __call__(self, batch: NumpyBatch) -> NumpyBatch: - """Add zeroed-out NWP data to batch with shape accoriding to supplied configuration. - - Batch is modified in-place and returned. 
- - Args: - batch: Numpy batch of input data. - """ - - config = self.configuration.input_data.nwp - - n_channels = len(config.nwp_channels) - height = config.nwp_image_size_pixels_height - width = config.nwp_image_size_pixels_width - - sequence_len = config.history_minutes // 60 + config.forecast_minutes // 60 + 1 - - batch[BatchKey.nwp] = np.zeros((sequence_len, n_channels, height, width)) - - return batch - - -class DatapipeKeyForker: - """ "Internal helper function to track forking of a datapipe.""" - - def __init__(self, keys: List, datapipe: IterDataPipe): - """Internal helper function to track forking of a datapipe. - - As forks are returned, this object tracks the keys left and returns the final copy of the - datapipe when the last key is requested. This makes multiple forking easier and ensures - closure. - - Args: - keys: List of keys for which datapipe duplication is required. - datapipe: Datapipe which will be forked for each ket - """ - self.keys_left = keys - self.datapipe = datapipe - - def __call__(self, key): - """ "Returns a fork of `self.datapipe` and tracks a the keys left to ensure closure. - - Args: - key: key to remove from `self.keys_left`. If `key` is None then an extra copy is made - without affecting `self.keys_left`. - """ - if len(self.keys_left) == 0: - raise ValueError(f"No keys left when requested key : {key}") - if key is not None: - self.keys_left.remove(key) - if len(self.keys_left) > 0: - self.datapipe, return_datapipe = self.datapipe.fork(2, buffer_size=5) - else: - return_datapipe = self.datapipe - return return_datapipe - - def close(self): - """Asserts that the keys have all been used.""" - assert len(self.keys_left) == 0 - - -def _get_datapipes_dict( - config_filename: str, - block_sat: bool, - block_nwp: bool, - production: bool = False, -): - # Load datasets - datapipes_dict = open_and_return_datapipes( - configuration_filename=config_filename, - use_gsp=(not production), - use_pv=(not production), - use_sat=not block_sat, # Only loaded if we aren't replacing them with zeros - use_hrv=False, - use_nwp=not block_nwp, # Only loaded if we aren't replacing them with zeros - use_topo=False, - production=production, - ) - - config: Configuration = datapipes_dict["config"] - - if production: - datapipes_dict["gsp"] = OpenGSPFromDatabase().add_t0_idx_and_sample_period_duration( - sample_period_duration=timedelta(minutes=30), - history_duration=timedelta(minutes=config.input_data.gsp.history_minutes), - ) - if "sat" in datapipes_dict: - datapipes_dict["sat"] = datapipes_dict["sat"].map(production_sat_scale) - if "pv" in datapipes_dict: - datapipes_dict["pv"] = OpenPVFromPVSitesDB(config.input_data.pv.history_minutes) - - if "pv" in datapipes_dict and config.input_data.pv.pv_ml_ids != []: - datapipes_dict["pv"] = datapipes_dict["pv"].pvnet_select_pv_by_ml_id( - config.input_data.pv.pv_ml_ids - ) - - return datapipes_dict - - -def construct_loctime_pipelines( - config_filename: str, - start_time: Optional[datetime] = None, - end_time: Optional[datetime] = None, - block_sat: bool = False, - block_nwp: bool = False, -) -> Tuple[IterDataPipe, IterDataPipe]: - """Construct location and time pipelines for the input data config file. - - Args: - config_filename: Path to config file. - start_time: Minimum time for time datapipe. - end_time: Maximum time for time datapipe. - block_sat: Whether to load zeroes for satellite data. - block_nwp: Whether to load zeroes for NWP data. 
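
The `DatapipeKeyForker` helper being removed from pvnet.py just above is worth a short illustration, since the common-module `slice_datapipes_by_time` relies on it to hand out exactly one fork of the t0 datapipe per input source and to fail loudly if any fork goes unused. A hedged usage sketch follows; the `ocf_datapipes.training.common` import path is an assumption, as the class's new home is not shown in this hunk:

    from torchdata.datapipes.iter import IterableWrapper

    # Assumed import path: the class is deleted from pvnet.py in this patch
    # and presumably now lives next to slice_datapipes_by_time in common.py.
    from ocf_datapipes.training.common import DatapipeKeyForker

    t0_datapipe = IterableWrapper([0, 1, 2])
    get_t0 = DatapipeKeyForker(keys=["nwp", "sat", "gsp"], datapipe=t0_datapipe)

    nwp_t0 = get_t0("nwp")   # fork 1 of 3, "nwp" is crossed off the list
    sat_t0 = get_t0("sat")   # fork 2 of 3
    extra_t0 = get_t0(None)  # extra copy, does not consume a key
    gsp_t0 = get_t0("gsp")   # last key: the remaining pipe itself is returned
    get_t0.close()           # passes only because every key was requested

Requesting an unknown key, requesting after the list is exhausted, or calling `close()` while keys remain all raise, which is the point of the helper: it guards against silently unconsumed or missing forks.
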
- """ - - datapipes_dict = _get_datapipes_dict( - config_filename, - block_sat, - block_nwp, - ) - - # Pull out config file - config = datapipes_dict.pop("config") - - # We sample time and space of other data using GSP time and space coordinates, so filter GSP - # data first amd this is carried through - datapipes_dict["gsp"] = datapipes_dict["gsp"].map(gsp_drop_national) - if (start_time is not None) or (end_time is not None): - datapipes_dict["gsp"] = datapipes_dict["gsp"].select_train_test_time(start_time, end_time) - - # Get overlapping time periods - location_pipe, t0_datapipe = create_t0_and_loc_datapipes( - datapipes_dict, - configuration=config, - key_for_t0="gsp", - shuffle=True, - nwp_max_dropout_minutes=180, - # Sometimes the forecast is only 4/day so 6 hour intervals - then we add 3-hour dropout - nwp_max_staleness_minutes=60 * 9, - ) - - return location_pipe, t0_datapipe - - -def minutes(num_mins: int): - """Timedelta of a number of minutes. - - Args: - num_mins: Minutes timedelta. - """ - return timedelta(minutes=num_mins) - - -def slice_datapipes_by_time( - datapipes_dict: Dict, - t0_datapipe: IterDataPipe, - configuration: Configuration, - production: bool = False, -) -> None: - """ - Modifies a dictionary of datapipes in-place to yield samples for given times t0. - - The NWP data* will be at least 90 minutes stale (i.e. as if it takes 90 minutes for the foreast - to become available). - - The satellite data* is shaped so that the most recent can be 15 minutes before t0. However, 50% - of the time dropout is applied so that the most recent field is between 45 and 20 minutes before - t0. When dropped out like this, the values after this selected dropout time are set to NaN. - - The HRV data* is similar to the satellite data and if both are included they drop out - simulataneously. - - The GSP data is split into "gsp" and "gsp_future" keys. 10% of the time the gsp value for time - t0, which occurs under the "gsp" key, is set to NaN - - The PV data* is also split it "pv" and "pv_future" keys. - - * if included - - n.b. PV and HRV are included in this function, but not yet in the rest of the pvnet pipeline. - This is mostly for demonstratio purposes of how the dropout might be applied. - - Args: - datapipes_dict: Dictionary of used datapipes and t0 ones - t0_datapipe: Datapipe which yields t0 times for sample - configuration: Configuration object. - production: Whether constucting pipeline for production inference. No dropout is used if - True. - - """ - - conf_in = configuration.input_data - - # Use DatapipeKeyForker to avoid forking t0_datapipe too many times, or leaving any forks unused - fork_keys = {k for k in datapipes_dict.keys() if k not in ["topo"]} - get_t0_datapipe = DatapipeKeyForker(fork_keys, t0_datapipe) - - sat_and_hrv_dropout_kwargs = dict( - # Satellite is either 30 minutes or 60 minutes delayed in production. 
Match during training - dropout_timedeltas=[minutes(-60), minutes(-30)], - dropout_frac=0 if production else 1.0, - ) - - sat_delay = minutes(-configuration.input_data.satellite.live_delay_minutes) - - if "nwp" in datapipes_dict: - datapipes_dict["nwp"] = datapipes_dict["nwp"].convert_to_nwp_target_time_with_dropout( - t0_datapipe=get_t0_datapipe("nwp"), - sample_period_duration=minutes(60), - history_duration=minutes(conf_in.nwp.history_minutes), - forecast_duration=minutes(conf_in.nwp.forecast_minutes), - # The NWP forecast will always be at least 180 minutes stale - dropout_timedeltas=[minutes(-180)], - dropout_frac=0 if production else 1.0, - ) - - if "sat" in datapipes_dict: - # Take time slices of sat data - datapipes_dict["sat"] = datapipes_dict["sat"].select_time_slice( - t0_datapipe=get_t0_datapipe(None), - sample_period_duration=minutes(5), - interval_start=minutes(-conf_in.satellite.history_minutes), - interval_end=sat_delay, - fill_selection=production, - max_steps_gap=2, - ) - - # Generate randomly sampled dropout times - sat_dropout_time_datapipe = get_t0_datapipe("sat").select_dropout_time( - **sat_and_hrv_dropout_kwargs - ) - - if "hrv" in datapipes_dict: - # Make dropout-time copy for hrv if included in data. - # HRV and non-HRV will dropout simultaneously. - sat_dropout_time_datapipe, hrv_dropout_time_datapipe = sat_dropout_time_datapipe.fork( - 2, buffer_size=5 - ) - - # Apply the dropout - datapipes_dict["sat"] = datapipes_dict["sat"].apply_dropout_time( - dropout_time_datapipe=sat_dropout_time_datapipe, - ) - - if "hrv" in datapipes_dict: - if "sat" not in datapipes_dict: - # Generate randomly sampled dropout times - # This is shared with sat if sat included - hrv_dropout_time_datapipe = get_t0_datapipe(None).select_dropout_time( - **sat_and_hrv_dropout_kwargs - ) - - datapipes_dict["hrv"] = datapipes_dict["hrv"].select_time_slice( - t0_datapipe=get_t0_datapipe("hrv"), - sample_period_duration=minutes(5), - interval_start=minutes(-conf_in.hrvsatellite.history_minutes), - interval_end=sat_delay, - fill_selection=production, - max_steps_gap=2, - ) - - # Apply the dropout - datapipes_dict["hrv"] = datapipes_dict["hrv"].apply_dropout_time( - dropout_time_datapipe=hrv_dropout_time_datapipe, - ) - - if "pv" in datapipes_dict: - datapipes_dict["pv"], dp = datapipes_dict["pv"].fork(2, buffer_size=5) - - datapipes_dict["pv_future"] = dp.select_time_slice( - t0_datapipe=get_t0_datapipe(None), - sample_period_duration=minutes(5), - interval_start=minutes(5), - interval_end=minutes(conf_in.pv.forecast_minutes), - fill_selection=production, - ) - - datapipes_dict["pv"] = datapipes_dict["pv"].select_time_slice( - t0_datapipe=get_t0_datapipe(None), - sample_period_duration=minutes(5), - interval_start=minutes(-conf_in.pv.history_minutes), - interval_end=minutes(0), - fill_selection=production, - ) - - # Dropout on the PV, but not the future PV - pv_dropout_time_datapipe = get_t0_datapipe("pv").select_dropout_time( - # All PV data could be delayed by up to 30 minutes - # (this does not stem from production - just setting for now) - dropout_timedeltas=[minutes(m) for m in range(-30, 0, 5)], - dropout_frac=0.1 if production else 1, - ) - - datapipes_dict["pv"] = datapipes_dict["pv"].apply_dropout_time( - dropout_time_datapipe=pv_dropout_time_datapipe, - ) - - # Apply extra PV dropout using different delays per system and droping out entire PV systems - # independently - if not production: - datapipes_dict["pv"].apply_pv_dropout( - system_dropout_fractions=np.linspace(0, 0.2, 100), - 
system_dropout_timedeltas=[minutes(m) for m in [-15, -10, -5, 0]], - ) - - if "gsp" in datapipes_dict: - datapipes_dict["gsp"], dp = datapipes_dict["gsp"].fork(2, buffer_size=5) - - datapipes_dict["gsp_future"] = dp.select_time_slice( - t0_datapipe=get_t0_datapipe(None), - sample_period_duration=minutes(30), - interval_start=minutes(30), - interval_end=minutes(conf_in.gsp.forecast_minutes), - fill_selection=production, - ) - - datapipes_dict["gsp"] = datapipes_dict["gsp"].select_time_slice( - t0_datapipe=get_t0_datapipe(None), - sample_period_duration=minutes(30), - interval_start=-minutes(conf_in.gsp.history_minutes), - interval_end=minutes(0), - fill_selection=production, - ) - - # Dropout on the GSP, but not the future GSP - gsp_dropout_time_datapipe = get_t0_datapipe("gsp").select_dropout_time( - # GSP data for time t0 may be missing. Only have value for t0-30mins - dropout_timedeltas=[minutes(-30)], - dropout_frac=0 if production else 0.1, - ) - - datapipes_dict["gsp"] = datapipes_dict["gsp"].apply_dropout_time( - dropout_time_datapipe=gsp_dropout_time_datapipe, - ) - - get_t0_datapipe.close() - - return - - def construct_sliced_data_pipeline( config_filename: str, location_pipe: IterDataPipe, From f23be62db33b9227d55ef8d9a3ebd392765e188e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 16 Nov 2023 16:41:00 +0000 Subject: [PATCH 27/29] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- ocf_datapipes/training/pvnet.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/ocf_datapipes/training/pvnet.py b/ocf_datapipes/training/pvnet.py index 62fa9e4a9..985384969 100644 --- a/ocf_datapipes/training/pvnet.py +++ b/ocf_datapipes/training/pvnet.py @@ -1,16 +1,12 @@ """Create the training/validation datapipe for training the PVNet Model""" import logging -from datetime import datetime, timedelta -from typing import Dict, List, Optional, Tuple, Union +from datetime import datetime +from typing import Optional -import numpy as np import xarray as xr -from torchdata.datapipes import functional_datapipe from torchdata.datapipes.iter import IterDataPipe from ocf_datapipes.batch import MergeNumpyModalities -from ocf_datapipes.config.model import Configuration -from ocf_datapipes.load import OpenGSPFromDatabase, OpenPVFromPVSitesDB from ocf_datapipes.training.common import ( AddZeroedNWPData, AddZeroedSatelliteData, @@ -29,8 +25,6 @@ NEW_NWP_STD, RSS_MEAN, RSS_STD, - BatchKey, - NumpyBatch, ) xr.set_options(keep_attrs=True) From 5a8e15c252210a5ddfdefab7c3223cbf571bba5f Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Thu, 16 Nov 2023 18:47:08 +0000 Subject: [PATCH 28/29] Re-add fix --- ocf_datapipes/training/common.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ocf_datapipes/training/common.py b/ocf_datapipes/training/common.py index 51b088093..eb358b0d7 100644 --- a/ocf_datapipes/training/common.py +++ b/ocf_datapipes/training/common.py @@ -480,6 +480,8 @@ def construct_loctime_pipelines( key_for_t0="gsp", shuffle=True, nwp_max_dropout_minutes=180, + # Sometimes the forecast is only 4/day so 6 hour intervals - then we add 3-hour dropout + nwp_max_staleness_minutes=60 * 9, ) return location_pipe, t0_datapipe From 62ea906029a74e155d49af0f3c6d092f7f6f792f Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Thu, 16 Nov 2023 18:55:21 +0000 Subject: [PATCH 29/29] Add missing function --- ocf_datapipes/training/common.py | 93 
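
Circling back to the dropout logic in `slice_datapipes_by_time` above: `select_dropout_time` draws a delayed cut-off relative to t0 and `apply_dropout_time` masks everything observed after it. A rough standalone equivalent, for intuition only (the `apply_history_dropout` helper is not part of the library, and it assumes the array carries a `time_utc` coordinate):

    import numpy as np
    import pandas as pd

    def apply_history_dropout(da, t0, dropout_timedeltas, dropout_frac, rng=None):
        # With probability `dropout_frac`, pick one of the (negative) timedeltas
        # and NaN-out every value more recent than t0 + that delay.
        rng = np.random.default_rng() if rng is None else rng
        if rng.random() >= dropout_frac:
            return da
        delay = dropout_timedeltas[rng.integers(len(dropout_timedeltas))]
        cutoff = pd.Timestamp(t0) + delay
        return da.where(da.time_utc <= cutoff)

With the satellite settings above (`dropout_timedeltas=[minutes(-60), minutes(-30)]` and `dropout_frac=1.0` in training), every training sample therefore has any satellite imagery more recent than t0 minus 30 or 60 minutes masked to NaN, while production samples are left untouched.
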
++++++++++++++++++++++++++++++++ 1 file changed, 93 insertions(+) diff --git a/ocf_datapipes/training/common.py b/ocf_datapipes/training/common.py index eb358b0d7..3ed962013 100644 --- a/ocf_datapipes/training/common.py +++ b/ocf_datapipes/training/common.py @@ -152,6 +152,99 @@ def open_and_return_datapipes( return used_datapipes +def get_and_return_overlapping_time_periods_and_t0(used_datapipes: dict, key_for_t0: str = "gsp"): + """ + Takes datapipes and obtains the overlapping time periods + t0 time datapipes + + Args: + used_datapipes: Dictionary of datapipes to compute the time intersection of + key_for_t0: Key to use for the t0 datapipe + + Returns: + Dictionary of datapipes with the proper time slices selected + """ + datapipes_for_time_periods = [] # Using later to compute intersections + datapipes_to_return = {} # Returned along with original ones + t0_datapipe = None + configuration = used_datapipes.pop("config") + for key, datapipe in used_datapipes.items(): + if "topo" in key: + continue + if key_for_t0 in key: + forked_datapipes = datapipe.fork(3, buffer_size=100) + t0_datapipe = forked_datapipes[2] + else: + forked_datapipes = datapipe.fork(2, buffer_size=100) + datapipes_to_return[key] = forked_datapipes[0] + if "nwp" == key: + time_periods_datapipe = forked_datapipes[1].get_contiguous_time_periods( + sample_period_duration=timedelta(hours=3), # Init times are 3 hours apart + history_duration=timedelta(minutes=configuration.input_data.nwp.history_minutes), + forecast_duration=timedelta(minutes=configuration.input_data.nwp.forecast_minutes), + time_dim="init_time_utc", + ) + datapipes_for_time_periods.append(time_periods_datapipe) + + if "sat" == key: + time_periods_datapipe = forked_datapipes[1].get_contiguous_time_periods( + sample_period_duration=timedelta(minutes=5), + history_duration=timedelta( + minutes=configuration.input_data.satellite.history_minutes + ), + forecast_duration=timedelta(minutes=0), + ) + datapipes_for_time_periods.append(time_periods_datapipe) + + if "hrv" == key: + time_periods_datapipe = forked_datapipes[1].get_contiguous_time_periods( + sample_period_duration=timedelta(minutes=5), + history_duration=timedelta( + minutes=configuration.input_data.hrvsatellite.history_minutes + ), + forecast_duration=timedelta(minutes=0), + ) + datapipes_for_time_periods.append(time_periods_datapipe) + + if "pv" == key: + time_periods_datapipe = forked_datapipes[1].get_contiguous_time_periods( + sample_period_duration=timedelta(minutes=5), + history_duration=timedelta(minutes=configuration.input_data.pv.history_minutes), + forecast_duration=timedelta(minutes=configuration.input_data.pv.forecast_minutes), + ) + datapipes_for_time_periods.append(time_periods_datapipe) + if "gsp" == key: + time_periods_datapipe = forked_datapipes[1].get_contiguous_time_periods( + sample_period_duration=timedelta(minutes=30), + history_duration=timedelta(minutes=configuration.input_data.gsp.history_minutes), + forecast_duration=timedelta(minutes=configuration.input_data.gsp.forecast_minutes), + ) + datapipes_for_time_periods.append(time_periods_datapipe) + + # Now have the forked ones + # find joint overlapping timer periods + logger.debug("Getting joint time periods") + overlapping_datapipe = datapipes_for_time_periods[0].select_overlapping_time_slice( + secondary_datapipes=datapipes_for_time_periods[1:], + ) + + # select time periods + t0_datapipe = t0_datapipe.select_time_periods(time_periods=overlapping_datapipe) + + num_t0_datapipes = len(datapipes_to_return.keys()) # One for each 
input + t0_datapipes = t0_datapipe.select_t0_time(return_all_times=False).fork( + num_t0_datapipes, buffer_size=100 + ) + + for i, key in enumerate(list(datapipes_to_return.keys())): + datapipes_to_return[key + "_t0"] = t0_datapipes[i] + + # Re-add config for later + datapipes_to_return["config"] = configuration + if "topo" in used_datapipes.keys(): + datapipes_to_return["topo"] = used_datapipes["topo"] + return datapipes_to_return + + def normalize_gsp(x): """Normalize the GSP data
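
Note on the function restored in this final patch: after each source's contiguous periods are collected, they are intersected (`select_overlapping_time_slice`) before t0 times are sampled and one `<key>_t0` fork is handed back per input. The intersection datapipe itself is not shown; a minimal pandas sketch of the idea, reusing the illustrative `start_dt`/`end_dt` columns from the earlier note, could be:

    import pandas as pd

    def intersect_periods(all_periods):
        # Keep only the time ranges covered by every source: intersect the
        # first source's periods with each of the others in turn.
        joint = all_periods[0]
        for other in all_periods[1:]:
            rows = []
            for _, a in joint.iterrows():
                for _, b in other.iterrows():
                    start = max(a.start_dt, b.start_dt)
                    end = min(a.end_dt, b.end_dt)
                    if start <= end:
                        rows.append((start, end))
            joint = pd.DataFrame(rows, columns=["start_dt", "end_dt"])
        return joint

Only t0 values falling inside the joint periods are then yielded by `select_t0_time`, which is what keeps every modality's history and forecast windows populated for each sample.
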