diff --git a/ocf_datapipes/training/common.py b/ocf_datapipes/training/common.py index 7c668f331..3ed962013 100644 --- a/ocf_datapipes/training/common.py +++ b/ocf_datapipes/training/common.py @@ -1,18 +1,25 @@ """Common functionality for datapipes""" import logging -from datetime import timedelta +from datetime import datetime, timedelta +from typing import Dict, List, Optional, Tuple, Union +import numpy as np +import xarray as xr +from torch.utils.data import functional_datapipe from torchdata.datapipes.iter import IterDataPipe from ocf_datapipes.config.model import Configuration from ocf_datapipes.load import ( OpenConfiguration, OpenGSP, + OpenGSPFromDatabase, OpenNWP, OpenPVFromNetCDF, + OpenPVFromPVSitesDB, OpenSatellite, OpenTopography, ) +from ocf_datapipes.utils.consts import BatchKey, NumpyBatch logger = logging.getLogger(__name__) @@ -238,6 +245,570 @@ def get_and_return_overlapping_time_periods_and_t0(used_datapipes: dict, key_for return datapipes_to_return +def normalize_gsp(x): + """Normalize the GSP data + + Args: + x: Input DataArray + + Returns: + Normalized DataArray + """ + return x / x.effective_capacity_mwp + + +def normalize_pv(x): + """Normalize the PV data + + Args: + x: Input DataArray + + Returns: + Normalized DataArray + """ + return (x / x.nominal_capacity_wp).clip(None, 5) + + +def production_sat_scale(x): + """Scale the production satellite data + + Args: + x: Input DataArray + + Returns: + Scaled DataArray + """ + return x / 1024 + + +def concat_xr_time_utc(gsp_dataarrays: List[xr.DataArray]): + """This function is used to combine the split history and future gsp/pv dataarrays. + + These are split inside the `slice_datapipes_by_time()` function below. + + Splitting them inside that function allows us to apply dropout to the + history GSP/PV whilst leaving the future GSP/PV without NaNs. + + We recombine the history and future with this function to allow us to use the + `MergeNumpyModalities()` datapipe without redefining the BatchKeys. + + The `pvnet` model was also written to use a GSP/PV array which has historical and future + and to split it out. These maintains that assumption. + """ + return xr.concat(gsp_dataarrays, dim="time_utc") + + +def gsp_drop_national(x: Union[xr.DataArray, xr.Dataset]): + """Drop entries for national PV output + + Args: + x: Data source of gsp data + + Returns: + Filtered data source + """ + return x.where(x.gsp_id != 0, drop=True) + + +@functional_datapipe("pvnet_select_pv_by_ml_id") +class PVNetSelectPVbyMLIDIterDataPipe(IterDataPipe): + """Select specific set of PV systems by ML ID.""" + + def __init__(self, source_datapipe: IterDataPipe, ml_ids: np.array): + """Select specific set of PV systems by ML ID. + + Args: + source_datapipe: Datapipe emitting PV xarray data + ml_ids: List-like of ML IDs to select + + Returns: + Filtered data source + """ + self.source_datapipe = source_datapipe + self.ml_ids = ml_ids + + def __iter__(self): + for x in self.source_datapipe: + # Check for missing IDs + ml_ids_not_in_data = ~np.isin(self.ml_ids, x.ml_id) + if ml_ids_not_in_data.any(): + missing_ml_ids = np.array(self.ml_ids)[ml_ids_not_in_data] + logger.warning( + f"The following ML IDs were mising in the PV site-level input data: " + f"{missing_ml_ids}. The values for these IDs will be set to NaN." 
+ ) + + x_filtered = ( + # Many ML-IDs are null, so filter first + x.where(~x.ml_id.isnull(), drop=True) + # Swap dimensions so we can select by ml_id coordinate + .swap_dims({"pv_system_id": "ml_id"}) + # Select IDs - missing IDs are given NaN values + .reindex(ml_id=self.ml_ids) + # Swap back dimensions + .swap_dims({"ml_id": "pv_system_id"}) + ) + yield x_filtered + + +def fill_nans_in_pv(x: Union[xr.DataArray, xr.Dataset]): + """Fill NaNs in PV data with the value -1 + + Args: + x: Input DataArray + + Returns: + Normalized DataArray + """ + return x.fillna(-1) + + +def fill_nans_in_arrays(batch: NumpyBatch) -> NumpyBatch: + """Fills all NaN values in each np.ndarray in the batch dictionary with zeros. + + Operation is performed in-place on the batch. + """ + logger.info("Filling Nans with zeros") + for k, v in batch.items(): + if isinstance(v, np.ndarray): + np.nan_to_num(v, copy=False, nan=0.0) + return batch + + +class AddZeroedSatelliteData: + """A callable class used to add zeroed-out satellite data to batches of data. + + This is useful + to speed up batch loading if pre-training the output part of the network without satellite + inputs. + """ + + def __init__(self, configuration: Configuration, is_hrv: bool = False): + """A callable class used to add zeroed-out satellite data to batches of data. + + Args: + configuration: Configuration object + is_hrv: If False, non-HRV data is added by called function, else HRV. + """ + + self.configuration = configuration + self.is_hrv = is_hrv + + def __call__(self, batch: NumpyBatch) -> NumpyBatch: + """Add zeroed-out satellite data to batch with shape accoriding to supplied configuration. + + Batch is modified in-place and returned. + + Args: + batch: Numpy batch of input data. + """ + + variable = "hrvsatellite" if self.is_hrv else "satellite" + + satellite_config = getattr(self.configuration.input_data, variable) + + n_channels = len(getattr(satellite_config, f"{variable}_channels")) + height = getattr(satellite_config, f"{variable}_image_size_pixels_height") + width = getattr(satellite_config, f"{variable}_image_size_pixels_width") + + sequence_len = satellite_config.history_minutes // 5 + 1 - 3 + + batch[getattr(BatchKey, f"{variable}_actual")] = np.zeros( + (sequence_len, n_channels, height, width) + ) + + return batch + + +class AddZeroedNWPData: + """A callable class used to add zeroed-out NWP data to batches of data. + + This is useful to speed up batch loading if pre-training the output part of the network without + NWP inputs. + """ + + def __init__(self, configuration: Configuration): + """A callable class used to add zeroed-out NWP data to batches of data. + + Args: + configuration: Configuration object + """ + self.configuration = configuration + + def __call__(self, batch: NumpyBatch) -> NumpyBatch: + """Add zeroed-out NWP data to batch with shape accoriding to supplied configuration. + + Batch is modified in-place and returned. + + Args: + batch: Numpy batch of input data. 
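+
+        Returns:
+            The same batch, with a zero-filled array added under `BatchKey.nwp`.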
+ """ + + config = self.configuration.input_data.nwp + + n_channels = len(config.nwp_channels) + height = config.nwp_image_size_pixels_height + width = config.nwp_image_size_pixels_width + + sequence_len = config.history_minutes // 60 + config.forecast_minutes // 60 + 1 + + batch[BatchKey.nwp] = np.zeros((sequence_len, n_channels, height, width)) + + return batch + + +class DatapipeKeyForker: + """ "Internal helper function to track forking of a datapipe.""" + + def __init__(self, keys: List, datapipe: IterDataPipe): + """Internal helper function to track forking of a datapipe. + + As forks are returned, this object tracks the keys left and returns the final copy of the + datapipe when the last key is requested. This makes multiple forking easier and ensures + closure. + + Args: + keys: List of keys for which datapipe duplication is required. + datapipe: Datapipe which will be forked for each ket + """ + self.keys_left = keys + self.datapipe = datapipe + + def __call__(self, key): + """ "Returns a fork of `self.datapipe` and tracks a the keys left to ensure closure. + + Args: + key: key to remove from `self.keys_left`. If `key` is None then an extra copy is made + without affecting `self.keys_left`. + """ + if len(self.keys_left) == 0: + raise ValueError(f"No keys left when requested key : {key}") + if key is not None: + self.keys_left.remove(key) + if len(self.keys_left) > 0: + self.datapipe, return_datapipe = self.datapipe.fork(2, buffer_size=5) + else: + return_datapipe = self.datapipe + return return_datapipe + + def close(self): + """Asserts that the keys have all been used.""" + assert len(self.keys_left) == 0 + + +def _get_datapipes_dict( + config_filename: str, + block_sat: bool, + block_nwp: bool, + production: bool = False, +): + # Load datasets + datapipes_dict = open_and_return_datapipes( + configuration_filename=config_filename, + use_gsp=(not production), + use_pv=(not production), + use_sat=not block_sat, # Only loaded if we aren't replacing them with zeros + use_hrv=False, + use_nwp=not block_nwp, # Only loaded if we aren't replacing them with zeros + use_topo=False, + production=production, + ) + + config: Configuration = datapipes_dict["config"] + + if production: + datapipes_dict["gsp"] = OpenGSPFromDatabase().add_t0_idx_and_sample_period_duration( + sample_period_duration=timedelta(minutes=30), + history_duration=timedelta(minutes=config.input_data.gsp.history_minutes), + ) + if "sat" in datapipes_dict: + datapipes_dict["sat"] = datapipes_dict["sat"].map(production_sat_scale) + if "pv" in datapipes_dict: + datapipes_dict["pv"] = OpenPVFromPVSitesDB(config.input_data.pv.history_minutes) + + if "pv" in datapipes_dict and config.input_data.pv.pv_ml_ids != []: + datapipes_dict["pv"] = datapipes_dict["pv"].pvnet_select_pv_by_ml_id( + config.input_data.pv.pv_ml_ids + ) + + return datapipes_dict + + +def construct_loctime_pipelines( + config_filename: str, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, + block_sat: bool = False, + block_nwp: bool = False, +) -> Tuple[IterDataPipe, IterDataPipe]: + """Construct location and time pipelines for the input data config file. + + Args: + config_filename: Path to config file. + start_time: Minimum time for time datapipe. + end_time: Maximum time for time datapipe. + block_sat: Whether to load zeroes for satellite data. + block_nwp: Whether to load zeroes for NWP data. 
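+
+    Returns:
+        Tuple of (location datapipe, t0 time datapipe) used to sample training examples.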
+    """
+
+    datapipes_dict = _get_datapipes_dict(
+        config_filename,
+        block_sat,
+        block_nwp,
+    )
+
+    # Pull out config file
+    config = datapipes_dict.pop("config")
+
+    # We sample time and space of other data using GSP time and space coordinates, so filter GSP
+    # data first and this is carried through
+    datapipes_dict["gsp"] = datapipes_dict["gsp"].map(gsp_drop_national)
+    if (start_time is not None) or (end_time is not None):
+        datapipes_dict["gsp"] = datapipes_dict["gsp"].select_train_test_time(start_time, end_time)
+
+    # Get overlapping time periods
+    location_pipe, t0_datapipe = create_t0_and_loc_datapipes(
+        datapipes_dict,
+        configuration=config,
+        key_for_t0="gsp",
+        shuffle=True,
+        nwp_max_dropout_minutes=180,
+        # Sometimes the forecast is only run 4 times per day, i.e. 6-hour intervals - then we add 3-hour dropout
+        nwp_max_staleness_minutes=60 * 9,
+    )
+
+    return location_pipe, t0_datapipe
+
+
+def minutes(num_mins: int):
+    """Timedelta of a number of minutes.
+
+    Args:
+        num_mins: Number of minutes.
+    """
+    return timedelta(minutes=num_mins)
+
+
+def slice_datapipes_by_time(
+    datapipes_dict: Dict,
+    t0_datapipe: IterDataPipe,
+    configuration: Configuration,
+    production: bool = False,
+) -> None:
+    """
+    Modifies a dictionary of datapipes in-place to yield samples for given times t0.
+
+    The NWP data* will be at least 180 minutes stale (i.e. as if it takes 180 minutes for the
+    forecast to become available).
+
+    The satellite data* is sliced so that its most recent timestep is `live_delay_minutes` (from
+    the config) before t0. During training, dropout is always applied so that the most recent
+    unmasked timestep is either 30 or 60 minutes before t0; values after the selected dropout time
+    are set to NaN.
+
+    The HRV data* is treated like the satellite data, and if both are included they drop out
+    simultaneously.
+
+    The GSP data is split into "gsp" and "gsp_future" keys. 10% of the time the GSP value for time
+    t0, which occurs under the "gsp" key, is set to NaN.
+
+    The PV data* is also split into "pv" and "pv_future" keys.
+
+    * if included
+
+    n.b. PV and HRV are handled in this function, but not yet in the rest of the pvnet pipeline.
+    This is mostly for demonstration purposes of how the dropout might be applied.
+
+    Args:
+        datapipes_dict: Dictionary of datapipes to slice in-place
+        t0_datapipe: Datapipe which yields t0 times for each sample
+        configuration: Configuration object.
+        production: Whether constructing pipeline for production inference. No dropout is used if
+            True.
+
+    """
+
+    conf_in = configuration.input_data
+
+    # Use DatapipeKeyForker to avoid forking t0_datapipe too many times, or leaving any forks unused
+    fork_keys = {k for k in datapipes_dict.keys() if k not in ["topo"]}
+    get_t0_datapipe = DatapipeKeyForker(fork_keys, t0_datapipe)
+
+    sat_and_hrv_dropout_kwargs = dict(
+        # Satellite is either 30 minutes or 60 minutes delayed in production.
Match during training + dropout_timedeltas=[minutes(-60), minutes(-30)], + dropout_frac=0 if production else 1.0, + ) + + sat_delay = minutes(-configuration.input_data.satellite.live_delay_minutes) + + if "nwp" in datapipes_dict: + datapipes_dict["nwp"] = datapipes_dict["nwp"].convert_to_nwp_target_time_with_dropout( + t0_datapipe=get_t0_datapipe("nwp"), + sample_period_duration=minutes(60), + history_duration=minutes(conf_in.nwp.history_minutes), + forecast_duration=minutes(conf_in.nwp.forecast_minutes), + # The NWP forecast will always be at least 180 minutes stale + dropout_timedeltas=[minutes(-180)], + dropout_frac=0 if production else 1.0, + ) + + if "sat" in datapipes_dict: + # Take time slices of sat data + datapipes_dict["sat"] = datapipes_dict["sat"].select_time_slice( + t0_datapipe=get_t0_datapipe(None), + sample_period_duration=minutes(5), + interval_start=minutes(-conf_in.satellite.history_minutes), + interval_end=sat_delay, + fill_selection=production, + max_steps_gap=2, + ) + + # Generate randomly sampled dropout times + sat_dropout_time_datapipe = get_t0_datapipe("sat").select_dropout_time( + **sat_and_hrv_dropout_kwargs + ) + + if "hrv" in datapipes_dict: + # Make dropout-time copy for hrv if included in data. + # HRV and non-HRV will dropout simultaneously. + sat_dropout_time_datapipe, hrv_dropout_time_datapipe = sat_dropout_time_datapipe.fork( + 2, buffer_size=5 + ) + + # Apply the dropout + datapipes_dict["sat"] = datapipes_dict["sat"].apply_dropout_time( + dropout_time_datapipe=sat_dropout_time_datapipe, + ) + + if "hrv" in datapipes_dict: + if "sat" not in datapipes_dict: + # Generate randomly sampled dropout times + # This is shared with sat if sat included + hrv_dropout_time_datapipe = get_t0_datapipe(None).select_dropout_time( + **sat_and_hrv_dropout_kwargs + ) + + datapipes_dict["hrv"] = datapipes_dict["hrv"].select_time_slice( + t0_datapipe=get_t0_datapipe("hrv"), + sample_period_duration=minutes(5), + interval_start=minutes(-conf_in.hrvsatellite.history_minutes), + interval_end=sat_delay, + fill_selection=production, + max_steps_gap=2, + ) + + # Apply the dropout + datapipes_dict["hrv"] = datapipes_dict["hrv"].apply_dropout_time( + dropout_time_datapipe=hrv_dropout_time_datapipe, + ) + + if "pv" in datapipes_dict: + datapipes_dict["pv"], dp = datapipes_dict["pv"].fork(2, buffer_size=5) + + datapipes_dict["pv_future"] = dp.select_time_slice( + t0_datapipe=get_t0_datapipe(None), + sample_period_duration=minutes(5), + interval_start=minutes(5), + interval_end=minutes(conf_in.pv.forecast_minutes), + fill_selection=production, + ) + + datapipes_dict["pv"] = datapipes_dict["pv"].select_time_slice( + t0_datapipe=get_t0_datapipe(None), + sample_period_duration=minutes(5), + interval_start=minutes(-conf_in.pv.history_minutes), + interval_end=minutes(0), + fill_selection=production, + ) + + # Dropout on the PV, but not the future PV + pv_dropout_time_datapipe = get_t0_datapipe("pv").select_dropout_time( + # All PV data could be delayed by up to 30 minutes + # (this does not stem from production - just setting for now) + dropout_timedeltas=[minutes(m) for m in range(-30, 0, 5)], + dropout_frac=0.1 if production else 1, + ) + + datapipes_dict["pv"] = datapipes_dict["pv"].apply_dropout_time( + dropout_time_datapipe=pv_dropout_time_datapipe, + ) + + # Apply extra PV dropout using different delays per system and droping out entire PV systems + # independently + if not production: + datapipes_dict["pv"].apply_pv_dropout( + system_dropout_fractions=np.linspace(0, 0.2, 100), + 
system_dropout_timedeltas=[minutes(m) for m in [-15, -10, -5, 0]], + ) + + if "gsp" in datapipes_dict: + datapipes_dict["gsp"], dp = datapipes_dict["gsp"].fork(2, buffer_size=5) + + datapipes_dict["gsp_future"] = dp.select_time_slice( + t0_datapipe=get_t0_datapipe(None), + sample_period_duration=minutes(30), + interval_start=minutes(30), + interval_end=minutes(conf_in.gsp.forecast_minutes), + fill_selection=production, + ) + + datapipes_dict["gsp"] = datapipes_dict["gsp"].select_time_slice( + t0_datapipe=get_t0_datapipe(None), + sample_period_duration=minutes(30), + interval_start=-minutes(conf_in.gsp.history_minutes), + interval_end=minutes(0), + fill_selection=production, + ) + + # Dropout on the GSP, but not the future GSP + gsp_dropout_time_datapipe = get_t0_datapipe("gsp").select_dropout_time( + # GSP data for time t0 may be missing. Only have value for t0-30mins + dropout_timedeltas=[minutes(-30)], + dropout_frac=0 if production else 0.1, + ) + + datapipes_dict["gsp"] = datapipes_dict["gsp"].apply_dropout_time( + dropout_time_datapipe=gsp_dropout_time_datapipe, + ) + + get_t0_datapipe.close() + + return + + +def check_nans_in_satellite_data(batch: NumpyBatch) -> NumpyBatch: + """ + Check if there are any Nans values in the satellite data. + """ + if np.any(np.isnan(batch[BatchKey.satellite_actual])): + logger.error("Found nans values in satellite data") + + logger.error(batch[BatchKey.satellite_actual].shape) + + # loop over time and channels + for dim in [0, 1]: + for t in range(batch[BatchKey.satellite_actual].shape[dim]): + if dim == 0: + sate_data_one_step = batch[BatchKey.satellite_actual][t] + else: + sate_data_one_step = batch[BatchKey.satellite_actual][:, t] + nans = np.isnan(sate_data_one_step) + + if np.any(nans): + percent_nans = np.sum(nans) / np.prod(sate_data_one_step.shape) * 100 + + logger.error( + f"Found nans values in satellite data at index {t} ({dim=}). 
" + f"{percent_nans}% of values are nans" + ) + else: + logger.error(f"Found no nans values in satellite data at index {t} {dim=}") + + raise ValueError("Found nans values in satellite data") + + return batch + + def add_selected_time_slices_from_datapipes(used_datapipes: dict): """ Takes datapipes and t0 datapipes and returns the sliced datapipes diff --git a/ocf_datapipes/training/pvnet.py b/ocf_datapipes/training/pvnet.py index 7f8788481..985384969 100644 --- a/ocf_datapipes/training/pvnet.py +++ b/ocf_datapipes/training/pvnet.py @@ -1,564 +1,36 @@ """Create the training/validation datapipe for training the PVNet Model""" import logging -from datetime import datetime, timedelta -from typing import Dict, List, Optional, Tuple, Union +from datetime import datetime +from typing import Optional -import numpy as np import xarray as xr -from torchdata.datapipes import functional_datapipe from torchdata.datapipes.iter import IterDataPipe from ocf_datapipes.batch import MergeNumpyModalities -from ocf_datapipes.config.model import Configuration -from ocf_datapipes.load import OpenGSPFromDatabase, OpenPVFromPVSitesDB from ocf_datapipes.training.common import ( - create_t0_and_loc_datapipes, - open_and_return_datapipes, + AddZeroedNWPData, + AddZeroedSatelliteData, + _get_datapipes_dict, + check_nans_in_satellite_data, + concat_xr_time_utc, + construct_loctime_pipelines, + fill_nans_in_arrays, + fill_nans_in_pv, + normalize_gsp, + normalize_pv, + slice_datapipes_by_time, ) from ocf_datapipes.utils.consts import ( NEW_NWP_MEAN, NEW_NWP_STD, RSS_MEAN, RSS_STD, - BatchKey, - NumpyBatch, ) xr.set_options(keep_attrs=True) logger = logging.getLogger("pvnet_datapipe") -def normalize_gsp(x): - """Normalize the GSP data - - Args: - x: Input DataArray - - Returns: - Normalized DataArray - """ - return x / x.effective_capacity_mwp - - -def normalize_pv(x): - """Normalize the PV data - - Args: - x: Input DataArray - - Returns: - Normalized DataArray - """ - return (x / x.nominal_capacity_wp).clip(None, 5) - - -def production_sat_scale(x): - """Scale the production satellite data - - Args: - x: Input DataArray - - Returns: - Scaled DataArray - """ - return x / 1024 - - -def concat_xr_time_utc(gsp_dataarrays: List[xr.DataArray]): - """This function is used to combine the split history and future gsp/pv dataarrays. - - These are split inside the `slice_datapipes_by_time()` function below. - - Splitting them inside that function allows us to apply dropout to the - history GSP/PV whilst leaving the future GSP/PV without NaNs. - - We recombine the history and future with this function to allow us to use the - `MergeNumpyModalities()` datapipe without redefining the BatchKeys. - - The `pvnet` model was also written to use a GSP/PV array which has historical and future - and to split it out. These maintains that assumption. - """ - return xr.concat(gsp_dataarrays, dim="time_utc") - - -def gsp_drop_national(x: Union[xr.DataArray, xr.Dataset]): - """Drop entries for national PV output - - Args: - x: Data source of gsp data - - Returns: - Filtered data source - """ - return x.where(x.gsp_id != 0, drop=True) - - -@functional_datapipe("pvnet_select_pv_by_ml_id") -class PVNetSelectPVbyMLIDIterDataPipe(IterDataPipe): - """Select specific set of PV systems by ML ID.""" - - def __init__(self, source_datapipe: IterDataPipe, ml_ids: np.array): - """Select specific set of PV systems by ML ID. 
- - Args: - source_datapipe: Datapipe emitting PV xarray data - ml_ids: List-like of ML IDs to select - - Returns: - Filtered data source - """ - self.source_datapipe = source_datapipe - self.ml_ids = ml_ids - - def __iter__(self): - for x in self.source_datapipe: - # Check for missing IDs - ml_ids_not_in_data = ~np.isin(self.ml_ids, x.ml_id) - if ml_ids_not_in_data.any(): - missing_ml_ids = np.array(self.ml_ids)[ml_ids_not_in_data] - logger.warning( - f"The following ML IDs were mising in the PV site-level input data: " - f"{missing_ml_ids}. The values for these IDs will be set to NaN." - ) - - x_filtered = ( - # Many ML-IDs are null, so filter first - x.where(~x.ml_id.isnull(), drop=True) - # Swap dimensions so we can select by ml_id coordinate - .swap_dims({"pv_system_id": "ml_id"}) - # Select IDs - missing IDs are given NaN values - .reindex(ml_id=self.ml_ids) - # Swap back dimensions - .swap_dims({"ml_id": "pv_system_id"}) - ) - yield x_filtered - - -def fill_nans_in_pv(x: Union[xr.DataArray, xr.Dataset]): - """Fill NaNs in PV data with the value -1 - - Args: - x: Input DataArray - - Returns: - Normalized DataArray - """ - return x.fillna(-1) - - -def fill_nans_in_arrays(batch: NumpyBatch) -> NumpyBatch: - """Fills all NaN values in each np.ndarray in the batch dictionary with zeros. - - Operation is performed in-place on the batch. - """ - logger.info("Filling Nans with zeros") - for k, v in batch.items(): - if isinstance(v, np.ndarray): - np.nan_to_num(v, copy=False, nan=0.0) - return batch - - -class AddZeroedSatelliteData: - """A callable class used to add zeroed-out satellite data to batches of data. - - This is useful - to speed up batch loading if pre-training the output part of the network without satellite - inputs. - """ - - def __init__(self, configuration: Configuration, is_hrv: bool = False): - """A callable class used to add zeroed-out satellite data to batches of data. - - Args: - configuration: Configuration object - is_hrv: If False, non-HRV data is added by called function, else HRV. - """ - - self.configuration = configuration - self.is_hrv = is_hrv - - def __call__(self, batch: NumpyBatch) -> NumpyBatch: - """Add zeroed-out satellite data to batch with shape accoriding to supplied configuration. - - Batch is modified in-place and returned. - - Args: - batch: Numpy batch of input data. - """ - - variable = "hrvsatellite" if self.is_hrv else "satellite" - - satellite_config = getattr(self.configuration.input_data, variable) - - n_channels = len(getattr(satellite_config, f"{variable}_channels")) - height = getattr(satellite_config, f"{variable}_image_size_pixels_height") - width = getattr(satellite_config, f"{variable}_image_size_pixels_width") - - sequence_len = satellite_config.history_minutes // 5 + 1 - 3 - - batch[getattr(BatchKey, f"{variable}_actual")] = np.zeros( - (sequence_len, n_channels, height, width) - ) - - return batch - - -class AddZeroedNWPData: - """A callable class used to add zeroed-out NWP data to batches of data. - - This is useful to speed up batch loading if pre-training the output part of the network without - NWP inputs. - """ - - def __init__(self, configuration: Configuration): - """A callable class used to add zeroed-out NWP data to batches of data. - - Args: - configuration: Configuration object - """ - self.configuration = configuration - - def __call__(self, batch: NumpyBatch) -> NumpyBatch: - """Add zeroed-out NWP data to batch with shape accoriding to supplied configuration. - - Batch is modified in-place and returned. 
- - Args: - batch: Numpy batch of input data. - """ - - config = self.configuration.input_data.nwp - - n_channels = len(config.nwp_channels) - height = config.nwp_image_size_pixels_height - width = config.nwp_image_size_pixels_width - - sequence_len = config.history_minutes // 60 + config.forecast_minutes // 60 + 1 - - batch[BatchKey.nwp] = np.zeros((sequence_len, n_channels, height, width)) - - return batch - - -class DatapipeKeyForker: - """ "Internal helper function to track forking of a datapipe.""" - - def __init__(self, keys: List, datapipe: IterDataPipe): - """Internal helper function to track forking of a datapipe. - - As forks are returned, this object tracks the keys left and returns the final copy of the - datapipe when the last key is requested. This makes multiple forking easier and ensures - closure. - - Args: - keys: List of keys for which datapipe duplication is required. - datapipe: Datapipe which will be forked for each ket - """ - self.keys_left = keys - self.datapipe = datapipe - - def __call__(self, key): - """ "Returns a fork of `self.datapipe` and tracks a the keys left to ensure closure. - - Args: - key: key to remove from `self.keys_left`. If `key` is None then an extra copy is made - without affecting `self.keys_left`. - """ - if len(self.keys_left) == 0: - raise ValueError(f"No keys left when requested key : {key}") - if key is not None: - self.keys_left.remove(key) - if len(self.keys_left) > 0: - self.datapipe, return_datapipe = self.datapipe.fork(2, buffer_size=5) - else: - return_datapipe = self.datapipe - return return_datapipe - - def close(self): - """Asserts that the keys have all been used.""" - assert len(self.keys_left) == 0 - - -def _get_datapipes_dict( - config_filename: str, - block_sat: bool, - block_nwp: bool, - production: bool = False, -): - # Load datasets - datapipes_dict = open_and_return_datapipes( - configuration_filename=config_filename, - use_gsp=(not production), - use_pv=(not production), - use_sat=not block_sat, # Only loaded if we aren't replacing them with zeros - use_hrv=False, - use_nwp=not block_nwp, # Only loaded if we aren't replacing them with zeros - use_topo=False, - production=production, - ) - - config: Configuration = datapipes_dict["config"] - - if production: - datapipes_dict["gsp"] = OpenGSPFromDatabase().add_t0_idx_and_sample_period_duration( - sample_period_duration=timedelta(minutes=30), - history_duration=timedelta(minutes=config.input_data.gsp.history_minutes), - ) - if "sat" in datapipes_dict: - datapipes_dict["sat"] = datapipes_dict["sat"].map(production_sat_scale) - if "pv" in datapipes_dict: - datapipes_dict["pv"] = OpenPVFromPVSitesDB(config.input_data.pv.history_minutes) - - if "pv" in datapipes_dict and config.input_data.pv.pv_ml_ids != []: - datapipes_dict["pv"] = datapipes_dict["pv"].pvnet_select_pv_by_ml_id( - config.input_data.pv.pv_ml_ids - ) - - return datapipes_dict - - -def construct_loctime_pipelines( - config_filename: str, - start_time: Optional[datetime] = None, - end_time: Optional[datetime] = None, - block_sat: bool = False, - block_nwp: bool = False, -) -> Tuple[IterDataPipe, IterDataPipe]: - """Construct location and time pipelines for the input data config file. - - Args: - config_filename: Path to config file. - start_time: Minimum time for time datapipe. - end_time: Maximum time for time datapipe. - block_sat: Whether to load zeroes for satellite data. - block_nwp: Whether to load zeroes for NWP data. 
- """ - - datapipes_dict = _get_datapipes_dict( - config_filename, - block_sat, - block_nwp, - ) - - # Pull out config file - config = datapipes_dict.pop("config") - - # We sample time and space of other data using GSP time and space coordinates, so filter GSP - # data first amd this is carried through - datapipes_dict["gsp"] = datapipes_dict["gsp"].map(gsp_drop_national) - if (start_time is not None) or (end_time is not None): - datapipes_dict["gsp"] = datapipes_dict["gsp"].select_train_test_time(start_time, end_time) - - # Get overlapping time periods - location_pipe, t0_datapipe = create_t0_and_loc_datapipes( - datapipes_dict, - configuration=config, - key_for_t0="gsp", - shuffle=True, - nwp_max_dropout_minutes=180, - # Sometimes the forecast is only 4/day so 6 hour intervals - then we add 3-hour dropout - nwp_max_staleness_minutes=60 * 9, - ) - - return location_pipe, t0_datapipe - - -def minutes(num_mins: int): - """Timedelta of a number of minutes. - - Args: - num_mins: Minutes timedelta. - """ - return timedelta(minutes=num_mins) - - -def slice_datapipes_by_time( - datapipes_dict: Dict, - t0_datapipe: IterDataPipe, - configuration: Configuration, - production: bool = False, -) -> None: - """ - Modifies a dictionary of datapipes in-place to yield samples for given times t0. - - The NWP data* will be at least 90 minutes stale (i.e. as if it takes 90 minutes for the foreast - to become available). - - The satellite data* is shaped so that the most recent can be 15 minutes before t0. However, 50% - of the time dropout is applied so that the most recent field is between 45 and 20 minutes before - t0. When dropped out like this, the values after this selected dropout time are set to NaN. - - The HRV data* is similar to the satellite data and if both are included they drop out - simulataneously. - - The GSP data is split into "gsp" and "gsp_future" keys. 10% of the time the gsp value for time - t0, which occurs under the "gsp" key, is set to NaN - - The PV data* is also split it "pv" and "pv_future" keys. - - * if included - - n.b. PV and HRV are included in this function, but not yet in the rest of the pvnet pipeline. - This is mostly for demonstratio purposes of how the dropout might be applied. - - Args: - datapipes_dict: Dictionary of used datapipes and t0 ones - t0_datapipe: Datapipe which yields t0 times for sample - configuration: Configuration object. - production: Whether constucting pipeline for production inference. No dropout is used if - True. - - """ - - conf_in = configuration.input_data - - # Use DatapipeKeyForker to avoid forking t0_datapipe too many times, or leaving any forks unused - fork_keys = {k for k in datapipes_dict.keys() if k not in ["topo"]} - get_t0_datapipe = DatapipeKeyForker(fork_keys, t0_datapipe) - - sat_and_hrv_dropout_kwargs = dict( - # Satellite is either 30 minutes or 60 minutes delayed in production. 
Match during training - dropout_timedeltas=[minutes(-60), minutes(-30)], - dropout_frac=0 if production else 1.0, - ) - - sat_delay = minutes(-configuration.input_data.satellite.live_delay_minutes) - - if "nwp" in datapipes_dict: - datapipes_dict["nwp"] = datapipes_dict["nwp"].convert_to_nwp_target_time_with_dropout( - t0_datapipe=get_t0_datapipe("nwp"), - sample_period_duration=minutes(60), - history_duration=minutes(conf_in.nwp.history_minutes), - forecast_duration=minutes(conf_in.nwp.forecast_minutes), - # The NWP forecast will always be at least 180 minutes stale - dropout_timedeltas=[minutes(-180)], - dropout_frac=0 if production else 1.0, - ) - - if "sat" in datapipes_dict: - # Take time slices of sat data - datapipes_dict["sat"] = datapipes_dict["sat"].select_time_slice( - t0_datapipe=get_t0_datapipe(None), - sample_period_duration=minutes(5), - interval_start=minutes(-conf_in.satellite.history_minutes), - interval_end=sat_delay, - fill_selection=production, - max_steps_gap=2, - ) - - # Generate randomly sampled dropout times - sat_dropout_time_datapipe = get_t0_datapipe("sat").select_dropout_time( - **sat_and_hrv_dropout_kwargs - ) - - if "hrv" in datapipes_dict: - # Make dropout-time copy for hrv if included in data. - # HRV and non-HRV will dropout simultaneously. - sat_dropout_time_datapipe, hrv_dropout_time_datapipe = sat_dropout_time_datapipe.fork( - 2, buffer_size=5 - ) - - # Apply the dropout - datapipes_dict["sat"] = datapipes_dict["sat"].apply_dropout_time( - dropout_time_datapipe=sat_dropout_time_datapipe, - ) - - if "hrv" in datapipes_dict: - if "sat" not in datapipes_dict: - # Generate randomly sampled dropout times - # This is shared with sat if sat included - hrv_dropout_time_datapipe = get_t0_datapipe(None).select_dropout_time( - **sat_and_hrv_dropout_kwargs - ) - - datapipes_dict["hrv"] = datapipes_dict["hrv"].select_time_slice( - t0_datapipe=get_t0_datapipe("hrv"), - sample_period_duration=minutes(5), - interval_start=minutes(-conf_in.hrvsatellite.history_minutes), - interval_end=sat_delay, - fill_selection=production, - max_steps_gap=2, - ) - - # Apply the dropout - datapipes_dict["hrv"] = datapipes_dict["hrv"].apply_dropout_time( - dropout_time_datapipe=hrv_dropout_time_datapipe, - ) - - if "pv" in datapipes_dict: - datapipes_dict["pv"], dp = datapipes_dict["pv"].fork(2, buffer_size=5) - - datapipes_dict["pv_future"] = dp.select_time_slice( - t0_datapipe=get_t0_datapipe(None), - sample_period_duration=minutes(5), - interval_start=minutes(5), - interval_end=minutes(conf_in.pv.forecast_minutes), - fill_selection=production, - ) - - datapipes_dict["pv"] = datapipes_dict["pv"].select_time_slice( - t0_datapipe=get_t0_datapipe(None), - sample_period_duration=minutes(5), - interval_start=minutes(-conf_in.pv.history_minutes), - interval_end=minutes(0), - fill_selection=production, - ) - - # Dropout on the PV, but not the future PV - pv_dropout_time_datapipe = get_t0_datapipe("pv").select_dropout_time( - # All PV data could be delayed by up to 30 minutes - # (this does not stem from production - just setting for now) - dropout_timedeltas=[minutes(m) for m in range(-30, 0, 5)], - dropout_frac=0.1 if production else 1, - ) - - datapipes_dict["pv"] = datapipes_dict["pv"].apply_dropout_time( - dropout_time_datapipe=pv_dropout_time_datapipe, - ) - - # Apply extra PV dropout using different delays per system and droping out entire PV systems - # independently - if not production: - datapipes_dict["pv"].apply_pv_dropout( - system_dropout_fractions=np.linspace(0, 0.2, 100), - 
system_dropout_timedeltas=[minutes(m) for m in [-15, -10, -5, 0]], - ) - - if "gsp" in datapipes_dict: - datapipes_dict["gsp"], dp = datapipes_dict["gsp"].fork(2, buffer_size=5) - - datapipes_dict["gsp_future"] = dp.select_time_slice( - t0_datapipe=get_t0_datapipe(None), - sample_period_duration=minutes(30), - interval_start=minutes(30), - interval_end=minutes(conf_in.gsp.forecast_minutes), - fill_selection=production, - ) - - datapipes_dict["gsp"] = datapipes_dict["gsp"].select_time_slice( - t0_datapipe=get_t0_datapipe(None), - sample_period_duration=minutes(30), - interval_start=-minutes(conf_in.gsp.history_minutes), - interval_end=minutes(0), - fill_selection=production, - ) - - # Dropout on the GSP, but not the future GSP - gsp_dropout_time_datapipe = get_t0_datapipe("gsp").select_dropout_time( - # GSP data for time t0 may be missing. Only have value for t0-30mins - dropout_timedeltas=[minutes(-30)], - dropout_frac=0 if production else 0.1, - ) - - datapipes_dict["gsp"] = datapipes_dict["gsp"].apply_dropout_time( - dropout_time_datapipe=gsp_dropout_time_datapipe, - ) - - get_t0_datapipe.close() - - return - - def construct_sliced_data_pipeline( config_filename: str, location_pipe: IterDataPipe, @@ -725,36 +197,3 @@ def pvnet_datapipe( ) return datapipe - - -def check_nans_in_satellite_data(batch: NumpyBatch) -> NumpyBatch: - """ - Check if there are any Nans values in the satellite data. - """ - if np.any(np.isnan(batch[BatchKey.satellite_actual])): - logger.error("Found nans values in satellite data") - - logger.error(batch[BatchKey.satellite_actual].shape) - - # loop over time and channels - for dim in [0, 1]: - for t in range(batch[BatchKey.satellite_actual].shape[dim]): - if dim == 0: - sate_data_one_step = batch[BatchKey.satellite_actual][t] - else: - sate_data_one_step = batch[BatchKey.satellite_actual][:, t] - nans = np.isnan(sate_data_one_step) - - if np.any(nans): - percent_nans = np.sum(nans) / np.prod(sate_data_one_step.shape) * 100 - - logger.error( - f"Found nans values in satellite data at index {t} ({dim=}). 
" - f"{percent_nans}% of values are nans" - ) - else: - logger.error(f"Found no nans values in satellite data at index {t} {dim=}") - - raise ValueError("Found nans values in satellite data") - - return batch diff --git a/ocf_datapipes/training/windnet.py b/ocf_datapipes/training/windnet.py new file mode 100644 index 000000000..89e0c07b8 --- /dev/null +++ b/ocf_datapipes/training/windnet.py @@ -0,0 +1,385 @@ +"""Create the training/validation datapipe for training the PVNet Model""" +import logging +from datetime import datetime, timedelta +from typing import List, Optional, Tuple, Union + +import xarray as xr +from torchdata.datapipes import functional_datapipe +from torchdata.datapipes.iter import IterableWrapper, IterDataPipe + +from ocf_datapipes.batch import MergeNumpyModalities +from ocf_datapipes.config.model import Configuration +from ocf_datapipes.load import ( + OpenConfiguration, +) +from ocf_datapipes.training.common import ( + AddZeroedNWPData, + AddZeroedSatelliteData, + _get_datapipes_dict, + check_nans_in_satellite_data, + concat_xr_time_utc, + construct_loctime_pipelines, + fill_nans_in_arrays, + fill_nans_in_pv, + normalize_gsp, + normalize_pv, + slice_datapipes_by_time, +) +from ocf_datapipes.utils.consts import ( + NEW_NWP_MEAN, + NEW_NWP_STD, + RSS_MEAN, + RSS_STD, +) +from ocf_datapipes.utils.utils import combine_to_single_dataset, uncombine_from_single_dataset + +xr.set_options(keep_attrs=True) +logger = logging.getLogger("windnet_datapipe") + + +def scale_wind_speed_to_power(x: Union[xr.DataArray, xr.Dataset]): + """ + Scale wind speed to power to estimate the generation of wind power from ground sensors + + Roughly, double speed in m/s, and convert with the power scale + + Args: + x: xr.DataArray or xr.Dataset containing wind speed + + Returns: + Rescaled wind speed to MWh roughly + """ + # Convert knots to m/s + x = x * 0.514444 + # Roughly double speed to get power + x = x * 2 + return x + + +@functional_datapipe("dict_datasets") +class DictDatasetIterDataPipe(IterDataPipe): + """Create a dictionary of xr.Datasets from a set of iterators""" + + datapipes: Tuple[IterDataPipe] + length: Optional[int] + + def __init__(self, *datapipes: IterDataPipe, keys: List[str]): + """Init""" + if not all(isinstance(dp, IterDataPipe) for dp in datapipes): + raise TypeError( + "All inputs are required to be `IterDataPipe` " "for `ZipIterDataPipe`." 
+ ) + super().__init__() + self.keys = keys + self.datapipes = datapipes # type: ignore[assignment] + self.length = None + assert len(self.keys) == len(self.datapipes), "Number of keys must match number of pipes" + + def __iter__(self): + """Iter""" + iterators = [iter(datapipe) for datapipe in self.datapipes] + for data in zip(*iterators): + # Yield a dictionary of the data, using the keys in self.keys + yield {k: v for k, v in zip(self.keys, data)} + + +@functional_datapipe("load_dict_datasets") +class LoadDictDatasetIterDataPipe(IterDataPipe): + """Load NetCDF files and split them back into individual xr.Datasets""" + + filenames: List[str] + keys: List[str] + + def __init__(self, filenames: List[str], keys: List[str]): + """ + Load NetCDF files and split them back into individual xr.Datasets + + Args: + filenames: List of filesnames to load + keys: List of keys from each file to use, each key should be a + dataarray in the xr.Dataset + """ + super().__init__() + self.keys = keys + self.filenames = filenames + + def __iter__(self): + """Iterate through each filename, loading it, uncombining it, and then yielding it""" + while True: + for filename in self.filenames: + dataset = xr.open_dataset(filename) + datasets = uncombine_from_single_dataset(dataset) + # Yield a dictionary of the data, using the keys in self.keys + dataset_dict = {} + for k in self.keys: + dataset_dict[k] = datasets[k] + yield dataset_dict + + +@functional_datapipe("convert_to_numpy_batch") +class ConvertToNumpyBatchIterDataPipe(IterDataPipe): + """Converts Xarray Dataset to Numpy Batch""" + + def __init__( + self, + dataset_dict_dp: IterDataPipe, + configuration: Configuration, + block_sat: bool = False, + block_nwp: bool = False, + check_satellite_no_zeros: bool = False, + ): + """Init""" + super().__init__() + self.dataset_dict_dp = dataset_dict_dp + self.configuration = configuration + self.block_sat = block_sat + self.block_nwp = block_nwp + self.check_satellite_no_zeros = check_satellite_no_zeros + + def __iter__(self): + """Iter""" + for datapipes_dict in self.dataset_dict_dp: + # Spatially slice, normalize, and convert data to numpy arrays + numpy_modalities = [] + # Unpack for convenience + conf_sat = self.configuration.input_data.satellite + conf_nwp = self.configuration.input_data.nwp + if "nwp" in datapipes_dict: + numpy_modalities.append(datapipes_dict["nwp"].convert_nwp_to_numpy_batch()) + if "sat" in datapipes_dict: + numpy_modalities.append(datapipes_dict["sat"].convert_satellite_to_numpy_batch()) + if "pv" in datapipes_dict: + numpy_modalities.append(datapipes_dict["pv"].convert_pv_to_numpy_batch()) + numpy_modalities.append(datapipes_dict["gsp"].convert_gsp_to_numpy_batch()) + + logger.debug("Combine all the data sources") + combined_datapipe = MergeNumpyModalities(numpy_modalities).add_sun_position( + modality_name="gsp" + ) + + if self.block_sat and conf_sat != "": + sat_block_func = AddZeroedSatelliteData(self.configuration) + combined_datapipe = combined_datapipe.map(sat_block_func) + + if self.block_nwp and conf_nwp != "": + nwp_block_func = AddZeroedNWPData(self.configuration) + combined_datapipe = combined_datapipe.map(nwp_block_func) + + logger.info("Filtering out samples with no data") + if self.check_satellite_no_zeros: + # in production we don't want any nans in the satellite data + combined_datapipe = combined_datapipe.map(check_nans_in_satellite_data) + + combined_datapipe = combined_datapipe.map(fill_nans_in_arrays) + + yield next(iter(combined_datapipe)) + + +def minutes(num_mins: 
int): + """Timedelta of a number of minutes. + + Args: + num_mins: Minutes timedelta. + """ + return timedelta(minutes=num_mins) + + +def construct_sliced_data_pipeline( + config_filename: str, + location_pipe: IterDataPipe, + t0_datapipe: IterDataPipe, + block_sat: bool = False, + block_nwp: bool = False, + production: bool = False, +) -> dict: + """Constructs data pipeline for the input data config file. + + This yields samples from the location and time datapipes. + + Args: + config_filename: Path to config file. + location_pipe: Datapipe yielding locations. + t0_datapipe: Datapipe yielding times. + block_sat: Whether to load zeroes for satellite data. + block_nwp: Whether to load zeroes for NWP data. + production: Whether constucting pipeline for production inference. + """ + + assert not (production and (block_sat or block_nwp)) + + datapipes_dict = _get_datapipes_dict( + config_filename, + block_sat, + block_nwp, + production=production, + ) + + configuration = datapipes_dict.pop("config") + + # Unpack for convenience + conf_sat = configuration.input_data.satellite + conf_nwp = configuration.input_data.nwp + + # Slice all of the datasets by time - this is an in-place operation + slice_datapipes_by_time(datapipes_dict, t0_datapipe, configuration, production) + + if "nwp" in datapipes_dict: + nwp_datapipe = datapipes_dict["nwp"] + + location_pipe, location_pipe_copy = location_pipe.fork(2, buffer_size=5) + nwp_datapipe = nwp_datapipe.select_spatial_slice_pixels( + location_pipe_copy, + roi_height_pixels=conf_nwp.nwp_image_size_pixels_height, + roi_width_pixels=conf_nwp.nwp_image_size_pixels_width, + ) + nwp_datapipe = nwp_datapipe.normalize(mean=NEW_NWP_MEAN, std=NEW_NWP_STD) + + if "sat" in datapipes_dict: + sat_datapipe = datapipes_dict["sat"] + + location_pipe, location_pipe_copy = location_pipe.fork(2, buffer_size=5) + sat_datapipe = sat_datapipe.select_spatial_slice_pixels( + location_pipe_copy, + roi_height_pixels=conf_sat.satellite_image_size_pixels_height, + roi_width_pixels=conf_sat.satellite_image_size_pixels_width, + ) + sat_datapipe = sat_datapipe.normalize(mean=RSS_MEAN, std=RSS_STD) + + if "pv" in datapipes_dict: + # Recombine PV arrays - see function doc for further explanation + pv_datapipe = ( + datapipes_dict["pv"].zip_ocf(datapipes_dict["pv_future"]).map(concat_xr_time_utc) + ) + pv_datapipe = pv_datapipe.normalize(normalize_fn=normalize_pv) + pv_datapipe = pv_datapipe.map(fill_nans_in_pv) + + # GSP always assumed to be in data + location_pipe, location_pipe_copy = location_pipe.fork(2, buffer_size=5) + gsp_future_datapipe = datapipes_dict["gsp_future"] + gsp_future_datapipe = gsp_future_datapipe.select_spatial_slice_meters( + location_datapipe=location_pipe_copy, + roi_height_meters=1, + roi_width_meters=1, + dim_name="gsp_id", + ) + + gsp_datapipe = datapipes_dict["gsp"] + gsp_datapipe = gsp_datapipe.select_spatial_slice_meters( + location_datapipe=location_pipe, + roi_height_meters=1, + roi_width_meters=1, + dim_name="gsp_id", + ) + + # Recombine GSP arrays - see function doc for further explanation + gsp_datapipe = gsp_datapipe.zip_ocf(gsp_future_datapipe).map(concat_xr_time_utc) + gsp_datapipe = gsp_datapipe.normalize(normalize_fn=normalize_gsp) + + finished_dataset_dict = {"gsp": gsp_datapipe, "config": configuration} + if "nwp" in datapipes_dict: + finished_dataset_dict["nwp"] = nwp_datapipe + if "sat" in datapipes_dict: + finished_dataset_dict["sat"] = sat_datapipe + if "pv" in datapipes_dict: + finished_dataset_dict["pv"] = pv_datapipe + + return 
finished_dataset_dict + + +def windnet_datapipe( + config_filename: str, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, + block_sat: bool = False, + block_nwp: bool = False, +) -> IterDataPipe: + """ + Construct windnet pipeline for the input data config file. + + Args: + config_filename: Path to config file. + start_time: Minimum time at which a sample can be selected. + end_time: Maximum time at which a sample can be selected. + block_sat: Whether to load zeroes for satellite data. + block_nwp: Whether to load zeroes for NWP data. + """ + logger.info("Constructing windnet pipeline") + + # Open datasets from the config and filter to useable location-time pairs + location_pipe, t0_datapipe = construct_loctime_pipelines( + config_filename, + start_time, + end_time, + block_sat, + block_nwp, + ) + + # Shard after we have the loc-times. These are already shuffled so no need to shuffle again + location_pipe = location_pipe.sharding_filter() + t0_datapipe = t0_datapipe.sharding_filter() + + # In this function we re-open the datasets to make a clean separation before/after sharding + # This function + datapipe_dict = construct_sliced_data_pipeline( + config_filename, + location_pipe, + t0_datapipe, + block_sat, + block_nwp, + ) + + # Save out datapipe to NetCDF + + # Merge all the datapipes into one + return DictDatasetIterDataPipe( + datapipe_dict["gsp"], + datapipe_dict["nwp"], + datapipe_dict["sat"], + datapipe_dict["pv"], + keys=["gsp", "nwp", "sat", "pv"], + ).map(combine_to_single_dataset) + + +def split_dataset_dict_dp(element): + """ + Split the dictionary of datapipes into individual datapipes + + Args: + element: Dictionary of datapipes + """ + return {k: IterableWrapper([v]) for k, v in element.items() if k != "config"} + + +def windnet_netcdf_datapipe( + config_filename: str, + keys: List[str], + filenames: List[str], + block_sat: bool = False, + block_nwp: bool = False, +) -> IterDataPipe: + """ + Load the saved Datapipes from windnet, and transform to numpy batch + + Args: + config_filename: Path to config file. + keys: List of keys to extract from the single NetCDF files + filenames: List of NetCDF files to load + block_sat: Whether to load zeroes for satellite data. + block_nwp: Whether to load zeroes for NWP data. 
+ + Returns: + Datapipe that transforms the NetCDF files to numpy batch + """ + logger.info("Constructing windnet file pipeline") + config_datapipe = OpenConfiguration(config_filename) + configuration: Configuration = next(iter(config_datapipe)) + # Load files + datapipe_dict_dp: IterDataPipe = LoadDictDatasetIterDataPipe( + filenames=filenames, + keys=keys, + ).map(split_dataset_dict_dp) + datapipe = datapipe_dict_dp.convert_to_numpy_batch( + block_nwp=block_nwp, block_sat=block_sat, configuration=configuration + ) + + return datapipe diff --git a/ocf_datapipes/utils/utils.py b/ocf_datapipes/utils/utils.py index edbbf907c..66522b334 100644 --- a/ocf_datapipes/utils/utils.py +++ b/ocf_datapipes/utils/utils.py @@ -329,3 +329,86 @@ def trigonometric_datetime_transformation(datetimes: npt.ArrayLike) -> np.ndarra return np.concatenate( [sine_month, cosine_month, sine_day, cosine_day, sine_hour, cosine_hour], axis=1 ) + + +def combine_to_single_dataset(dataset_dict: dict[str, xr.Dataset]) -> xr.Dataset: + """ + Combine multiple datasets into a single dataset + + Args: + dataset_dict: Dictionary of xr.Dataset objects to combine + + Returns: + Combined dataset + """ + # Convert all data_arrays to datasets + new_dataset_dict = {} + for key, datasets in dataset_dict.items(): + new_datasets = [] + for dataset in datasets: + # Convert all coordinates float64 and int64 to float32 and int32 + dataset = dataset.assign_attrs( + {key: str(value) for key, value in dataset.attrs.items()} + ) + if isinstance(dataset, xr.DataArray): + new_datasets.append(dataset.to_dataset(name=key)) + else: + new_datasets.append(dataset) + assert isinstance(new_datasets[-1], xr.Dataset) + new_dataset_dict[key] = new_datasets + # Prepend all coordinates and dimensions names with the key in the dataset_dict + final_datasets_to_combined = [] + for key, datasets in new_dataset_dict.items(): + batched_datasets = [] + for dataset in datasets: + dataset = dataset.rename( + {dim: f"{key}__{dim}" for dim in dataset.dims if dim not in dataset.coords} + ) + dataset = dataset.rename({coord: f"{key}__{coord}" for coord in dataset.coords}) + batched_datasets.append(dataset) + # Merge all datasets with the same key + # If NWP, then has init_time_utc and step, so do it off key__init_time_utc + dataset = xr.concat( + batched_datasets, + dim=f"{key}__target_time_utc" + if f"{key}__target_time_utc" in dataset.coords + else f"{key}__time_utc", + ) + # Serialize attributes to be JSON-seriaizable + final_datasets_to_combined.append(dataset) + # Combine all datasets, and append the list of datasets to the dataset_dict + for f_dset in final_datasets_to_combined: + assert isinstance(f_dset, xr.Dataset), f"Dataset is not an xr.Dataset, {type(f_dset)}" + combined_dataset = xr.merge(final_datasets_to_combined) + # Print all attrbutes of the combined dataset + return combined_dataset + + +def uncombine_from_single_dataset(combined_dataset: xr.Dataset) -> dict[str, xr.DataArray]: + """ + Uncombine a combined dataset + + Args: + combined_dataset: The combined NetCDF dataset + + Returns: + The uncombined datasets as a dict of xr.Datasets + """ + # Split into datasets by splitting by the prefix added in combine_to_netcdf + datasets = {} + # Go through each data variable and split it into a dataset + for key, dataset in combined_dataset.items(): + # If 'key_' doesn't exist in a dim or coordinate, remove it + dataset_dims = list(dataset.coords) + for dim in dataset_dims: + if f"{key}__" not in dim: + dataset: xr.DataArray = dataset.drop(dim) + dataset = 
dataset.rename( + {dim: dim.split(f"{key}__")[1] for dim in dataset.dims if dim not in dataset.coords} + ) + dataset: xr.Dataset = dataset.rename( + {coord: coord.split(f"{key}__")[1] for coord in dataset.coords} + ) + # Split the dataset by the prefix + datasets[key] = dataset + return datasets diff --git a/ocf_datapipes/validation/check_for_nans.py b/ocf_datapipes/validation/check_for_nans.py index 421ce2a83..dfea01198 100644 --- a/ocf_datapipes/validation/check_for_nans.py +++ b/ocf_datapipes/validation/check_for_nans.py @@ -25,6 +25,7 @@ def __init__( source_datapipe: Datapipe emitting Xarray Datasets dataset_name: Optional name for dataset to check, if None, checks whole dataset fill_nans: Whether to fill NaNs with 0 or not + fill_value: Value to fill NaNs with """ self.source_datapipe = source_datapipe self.dataset_name = dataset_name diff --git a/tests/training/test_pvnet.py b/tests/training/test_pvnet.py index 035aa493e..1c83b4fb5 100644 --- a/tests/training/test_pvnet.py +++ b/tests/training/test_pvnet.py @@ -3,10 +3,10 @@ from torchdata.datapipes.iter import IterableWrapper from ocf_datapipes.training.pvnet import ( - construct_loctime_pipelines, construct_sliced_data_pipeline, pvnet_datapipe, ) +from ocf_datapipes.training.common import construct_loctime_pipelines from ocf_datapipes.utils.consts import Location import pytest diff --git a/tests/training/test_windnet.py b/tests/training/test_windnet.py new file mode 100644 index 000000000..4a89cffe0 --- /dev/null +++ b/tests/training/test_windnet.py @@ -0,0 +1,27 @@ +from datetime import datetime + +from ocf_datapipes.training.windnet import ( + windnet_datapipe, + windnet_netcdf_datapipe, +) +from ocf_datapipes.utils.utils import combine_to_single_dataset, uncombine_from_single_dataset +import pytest + + +def test_windnet_datapipe(configuration_filename): + start_time = datetime(1900, 1, 1) + end_time = datetime(2050, 1, 1) + dp = windnet_datapipe( + configuration_filename, + start_time=start_time, + end_time=end_time, + ) + datasets = next(iter(dp)) + # Need to serialize attributes to strings + datasets.to_netcdf("test.nc", mode="w", engine="h5netcdf", compute=True) + dp = windnet_netcdf_datapipe( + config_filename=configuration_filename, + filenames=["test.nc"], + keys=["gsp", "nwp", "sat", "pv"], + ) + datasets = next(iter(dp)) diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py index d6792559b..df537f047 100644 --- a/tests/utils/test_utils.py +++ b/tests/utils/test_utils.py @@ -1,5 +1,9 @@ import numpy as np from ocf_datapipes.utils.utils import searchsorted +from ocf_datapipes.utils.utils import combine_to_single_dataset, uncombine_from_single_dataset +from ocf_datapipes.training.windnet import windnet_datapipe +from datetime import datetime +import xarray as xr def test_searchsorted(): @@ -7,3 +11,28 @@ def test_searchsorted(): assert searchsorted(ys, 2.1) == 2 ys_r = np.array([5, 4, 3, 2, 1], dtype=np.float32) assert searchsorted(ys_r, 2.1, assume_ascending=False) == 3 + + +def test_combine_uncombine_from_single_dataset(configuration_filename): + start_time = datetime(1900, 1, 1) + end_time = datetime(2050, 1, 1) + dp = windnet_datapipe( + configuration_filename, + start_time=start_time, + end_time=end_time, + ) + dataset: xr.Dataset = next(iter(dp)) + assert isinstance(dataset, xr.Dataset) + multiple_datasets = uncombine_from_single_dataset(dataset) + for key in multiple_datasets.keys(): + if "time_utc" in multiple_datasets[key].coords.keys(): + time_coord = "time_utc" + else: + time_coord = 
"target_time_utc" + for i in range(len(multiple_datasets[key][time_coord])): + # Assert that data for each of the coords is the same + for coord_key in multiple_datasets[key][i].coords.keys(): + np.testing.assert_equal( + multiple_datasets[key].isel({time_coord: i})[coord_key].values, + dataset[key][i][f"{key}__{coord_key}"].values, + )