diff --git a/ocf_datapipes/training/metnet_pv_site.py b/ocf_datapipes/training/metnet_pv_site.py
index 601256b0f..9ac34aac3 100644
--- a/ocf_datapipes/training/metnet_pv_site.py
+++ b/ocf_datapipes/training/metnet_pv_site.py
@@ -4,6 +4,7 @@
 from pathlib import Path
 from typing import Union
 
+import numpy as np
 import xarray
 from torchdata.datapipes.iter import IterDataPipe
 
@@ -44,6 +45,12 @@ def _load_xarray_values(x):
     return x.load()
 
 
+def _filepath_fn(xr_data):
+    # NOTE(review): name is random, so the on-disk cache is written but never reused;
+    # TODO derive a stable filename from the sample's time/location metadata.
+    file_name = f"{np.random.randint(10000000)}.npy"
+    return file_name
+
+
 def metnet_site_datapipe(
     configuration_filename: Union[Path, str],
     use_sun: bool = True,
@@ -59,6 +66,7 @@ def metnet_site_datapipe(
     center_size_meters: int = 64_000,
     context_size_meters: int = 512_000,
     batch_size: int = 1,
+    cache_to_disk: bool = False,
 ) -> IterDataPipe:
     """
     Make GSP national data pipe
@@ -80,6 +88,7 @@ def metnet_site_datapipe(
         center_size_meters: Center size for MeNet cutouts, in meters
         context_size_meters: Context area size in meters
         batch_size: Batch size for the datapipe
+        cache_to_disk: Whether to cache to disk or not
 
     Returns: datapipe
     """
@@ -108,7 +117,8 @@ def metnet_site_datapipe(
     pv_datapipe = used_datapipes["pv_future"].normalize(normalize_fn=normalize_pv)
     # Split into GSP for target, only national, and one for history
     pv_datapipe, pv_loc_datapipe = pv_datapipe.fork(2)
-    pv_loc_datapipe, pv_id_datapipe = LocationPicker(pv_loc_datapipe).fork(2)
+    pv_loc_datapipe = LocationPicker(pv_loc_datapipe)
+    pv_loc_datapipe, pv_id_datapipe = pv_loc_datapipe.fork(2)
     pv_history = pv_history.select_id(pv_id_datapipe, data_source_name="pv")
 
     if "nwp" in used_datapipes.keys():
@@ -202,11 +212,21 @@ def metnet_site_datapipe(
         output_height_pixels=output_size,
         add_sun_features=use_sun,
     )
+    pv_datapipe = ConvertPVToNumpy(pv_datapipe)
     if not pv_in_image:
         pv_history = pv_history.map(_remove_nans)
         pv_history = ConvertPVToNumpy(pv_history, return_pv_id=True)
-        return metnet_datapipe.batch(batch_size).zip_ocf(
-            pv_history.batch(batch_size), pv_datapipe.batch(batch_size)
-        )
+        combined_datapipe = metnet_datapipe.batch(batch_size).zip_ocf(
+            pv_history.batch(batch_size), pv_datapipe.batch(batch_size)
+        )
     else:
-        return metnet_datapipe.batch(batch_size).zip_ocf(pv_datapipe.batch(batch_size))
+        combined_datapipe = metnet_datapipe.batch(batch_size).zip_ocf(
+            pv_datapipe.batch(batch_size)
+        )
+
+    # on_disk_cache must always be paired with end_caching; attach the cache
+    # holder only when caching is requested so the default path stays iterable.
+    if cache_to_disk:
+        combined_datapipe = combined_datapipe.on_disk_cache(filepath_fn=_filepath_fn)
+        combined_datapipe = combined_datapipe.end_caching()
+    return combined_datapipe
diff --git a/tests/training/test_metnet_pv_site.py b/tests/training/test_metnet_pv_site.py
index b289d6f36..f9df46a58 100644
--- a/tests/training/test_metnet_pv_site.py
+++ b/tests/training/test_metnet_pv_site.py
@@ -15,3 +15,14 @@ def test_metnet_datapipe():
     batch = next(iter(gsp_datapipe))
     assert np.isfinite(batch[0]).all()
     assert np.isfinite(batch[1]).all()
+
+
+def test_metnet_datapipe_cache():
+    filename = os.path.join(os.path.dirname(ocf_datapipes.__file__), "../tests/config/test.yaml")
+    gsp_datapipe = metnet_site_datapipe(
+        filename, use_nwp=False, pv_in_image=True, cache_to_disk=True
+    )
+
+    batch = next(iter(gsp_datapipe))
+    assert np.isfinite(batch[0]).all()
+    assert np.isfinite(batch[1]).all()