From 3531b696e411e7d80e3d787483ceef6dce4cb8c6 Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Mon, 3 Apr 2023 10:22:37 +0100 Subject: [PATCH 1/4] Add optional caching to metnet datapipe --- ocf_datapipes/training/metnet_pv_site.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/ocf_datapipes/training/metnet_pv_site.py b/ocf_datapipes/training/metnet_pv_site.py index 601256b0f..f2c8f1791 100644 --- a/ocf_datapipes/training/metnet_pv_site.py +++ b/ocf_datapipes/training/metnet_pv_site.py @@ -43,6 +43,10 @@ def _remove_nans(x): def _load_xarray_values(x): return x.load() +def _filepath_fn(xr_data): + # Get filepath from metadata, including time, and location and return it + file_name = f"{xr_data.time.values[0]}_{xr_data.pv_system_id.values[0]}_{xr_data.x_osgb.values[0]}_{xr_data.y_osgb.values[0]}.npy" + return file_name def metnet_site_datapipe( configuration_filename: Union[Path, str], @@ -59,6 +63,7 @@ def metnet_site_datapipe( center_size_meters: int = 64_000, context_size_meters: int = 512_000, batch_size: int = 1, + cache_to_disk: bool = False, ) -> IterDataPipe: """ Make GSP national data pipe @@ -80,6 +85,7 @@ def metnet_site_datapipe( center_size_meters: Center size for MeNet cutouts, in meters context_size_meters: Context area size in meters batch_size: Batch size for the datapipe + cache_to_disk: Whether to cache to disk or not Returns: datapipe """ @@ -108,8 +114,11 @@ def metnet_site_datapipe( pv_datapipe = used_datapipes["pv_future"].normalize(normalize_fn=normalize_pv) # Split into GSP for target, only national, and one for history pv_datapipe, pv_loc_datapipe = pv_datapipe.fork(2) - pv_loc_datapipe, pv_id_datapipe = LocationPicker(pv_loc_datapipe).fork(2) + pv_loc_datapipe = LocationPicker(pv_loc_datapipe) + pv_loc_datapipe, pv_id_datapipe = pv_loc_datapipe.fork(2) pv_history = pv_history.select_id(pv_id_datapipe, data_source_name="pv") + if cache_to_disk: + pv_history.on_disk_cache(filepath_fn=_filepath_fn) if "nwp" in used_datapipes.keys(): # take nwp time slices @@ -202,13 +211,18 @@ def metnet_site_datapipe( output_height_pixels=output_size, add_sun_features=use_sun, ) + pv_datapipe = ConvertPVToNumpy(pv_datapipe) if not pv_in_image: pv_history = pv_history.map(_remove_nans) pv_history = ConvertPVToNumpy(pv_history, return_pv_id=True) - return metnet_datapipe.batch(batch_size).zip_ocf( + combined_datapipe = metnet_datapipe.batch(batch_size).zip_ocf( pv_history.batch(batch_size), pv_datapipe.batch(batch_size) ) else: - return metnet_datapipe.batch(batch_size).zip_ocf(pv_datapipe.batch(batch_size)) + combined_datapipe = metnet_datapipe.batch(batch_size).zip_ocf(pv_datapipe.batch(batch_size)) + + if cache_to_disk: + combined_datapipe = combined_datapipe.end_caching() + return combined_datapipe From 8d085aacb3f4cc211efb8ad1a3f340f1a9841dbe Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 3 Apr 2023 09:24:19 +0000 Subject: [PATCH 2/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- ocf_datapipes/training/metnet_pv_site.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ocf_datapipes/training/metnet_pv_site.py b/ocf_datapipes/training/metnet_pv_site.py index f2c8f1791..897b7fb7a 100644 --- a/ocf_datapipes/training/metnet_pv_site.py +++ b/ocf_datapipes/training/metnet_pv_site.py @@ -43,11 +43,13 @@ def _remove_nans(x): def _load_xarray_values(x): return x.load() + def _filepath_fn(xr_data): # Get filepath from metadata, including time, and location and return it file_name = f"{xr_data.time.values[0]}_{xr_data.pv_system_id.values[0]}_{xr_data.x_osgb.values[0]}_{xr_data.y_osgb.values[0]}.npy" return file_name + def metnet_site_datapipe( configuration_filename: Union[Path, str], use_sun: bool = True, From 3bc5baf0576722e2963a34ab3715faeee759588d Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Tue, 4 Apr 2023 17:17:45 +0100 Subject: [PATCH 3/4] More trying caching Currently fails when turned on but works when not. --- ocf_datapipes/training/metnet_pv_site.py | 9 ++++----- tests/training/test_metnet_pv_site.py | 8 ++++++++ 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/ocf_datapipes/training/metnet_pv_site.py b/ocf_datapipes/training/metnet_pv_site.py index 897b7fb7a..e9fc1b106 100644 --- a/ocf_datapipes/training/metnet_pv_site.py +++ b/ocf_datapipes/training/metnet_pv_site.py @@ -3,6 +3,7 @@ import logging from pathlib import Path from typing import Union +import numpy as np import xarray from torchdata.datapipes.iter import IterDataPipe @@ -46,7 +47,7 @@ def _load_xarray_values(x): def _filepath_fn(xr_data): # Get filepath from metadata, including time, and location and return it - file_name = f"{xr_data.time.values[0]}_{xr_data.pv_system_id.values[0]}_{xr_data.x_osgb.values[0]}_{xr_data.y_osgb.values[0]}.npy" + file_name = f"{np.random.randint(10000000)}.npy" return file_name @@ -119,8 +120,6 @@ def metnet_site_datapipe( pv_loc_datapipe = LocationPicker(pv_loc_datapipe) pv_loc_datapipe, pv_id_datapipe = pv_loc_datapipe.fork(2) pv_history = pv_history.select_id(pv_id_datapipe, data_source_name="pv") - if cache_to_disk: - pv_history.on_disk_cache(filepath_fn=_filepath_fn) if "nwp" in used_datapipes.keys(): # take nwp time slices @@ -221,9 +220,9 @@ def metnet_site_datapipe( pv_history = ConvertPVToNumpy(pv_history, return_pv_id=True) combined_datapipe = metnet_datapipe.batch(batch_size).zip_ocf( pv_history.batch(batch_size), pv_datapipe.batch(batch_size) - ) + ).on_disk_cache(filepath_fn=_filepath_fn) else: - combined_datapipe = metnet_datapipe.batch(batch_size).zip_ocf(pv_datapipe.batch(batch_size)) + combined_datapipe = metnet_datapipe.batch(batch_size).zip_ocf(pv_datapipe.batch(batch_size)).on_disk_cache(filepath_fn=_filepath_fn) if cache_to_disk: combined_datapipe = combined_datapipe.end_caching() diff --git a/tests/training/test_metnet_pv_site.py b/tests/training/test_metnet_pv_site.py index b289d6f36..0a2c2f5df 100644 --- a/tests/training/test_metnet_pv_site.py +++ b/tests/training/test_metnet_pv_site.py @@ -15,3 +15,11 @@ def test_metnet_datapipe(): batch = next(iter(gsp_datapipe)) assert np.isfinite(batch[0]).all() assert np.isfinite(batch[1]).all() + +def test_metnet_datapipe_cache(): + filename = os.path.join(os.path.dirname(ocf_datapipes.__file__), "../tests/config/test.yaml") + gsp_datapipe = metnet_site_datapipe(filename, use_nwp=False, pv_in_image=True, cache_to_disk=True) + + batch = next(iter(gsp_datapipe)) + assert np.isfinite(batch[0]).all() + assert np.isfinite(batch[1]).all() From 4a2157407588ff1eaddd906c3688d5116901ddec Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 4 Apr 2023 16:18:02 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- ocf_datapipes/training/metnet_pv_site.py | 16 +++++++++++----- tests/training/test_metnet_pv_site.py | 5 ++++- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/ocf_datapipes/training/metnet_pv_site.py b/ocf_datapipes/training/metnet_pv_site.py index e9fc1b106..9ac34aac3 100644 --- a/ocf_datapipes/training/metnet_pv_site.py +++ b/ocf_datapipes/training/metnet_pv_site.py @@ -3,8 +3,8 @@ import logging from pathlib import Path from typing import Union -import numpy as np +import numpy as np import xarray from torchdata.datapipes.iter import IterDataPipe @@ -218,11 +218,17 @@ def metnet_site_datapipe( if not pv_in_image: pv_history = pv_history.map(_remove_nans) pv_history = ConvertPVToNumpy(pv_history, return_pv_id=True) - combined_datapipe = metnet_datapipe.batch(batch_size).zip_ocf( - pv_history.batch(batch_size), pv_datapipe.batch(batch_size) - ).on_disk_cache(filepath_fn=_filepath_fn) + combined_datapipe = ( + metnet_datapipe.batch(batch_size) + .zip_ocf(pv_history.batch(batch_size), pv_datapipe.batch(batch_size)) + .on_disk_cache(filepath_fn=_filepath_fn) + ) else: - combined_datapipe = metnet_datapipe.batch(batch_size).zip_ocf(pv_datapipe.batch(batch_size)).on_disk_cache(filepath_fn=_filepath_fn) + combined_datapipe = ( + metnet_datapipe.batch(batch_size) + .zip_ocf(pv_datapipe.batch(batch_size)) + .on_disk_cache(filepath_fn=_filepath_fn) + ) if cache_to_disk: combined_datapipe = combined_datapipe.end_caching() diff --git a/tests/training/test_metnet_pv_site.py b/tests/training/test_metnet_pv_site.py index 0a2c2f5df..f9df46a58 100644 --- a/tests/training/test_metnet_pv_site.py +++ b/tests/training/test_metnet_pv_site.py @@ -16,9 +16,12 @@ def test_metnet_datapipe(): assert np.isfinite(batch[0]).all() assert np.isfinite(batch[1]).all() + def test_metnet_datapipe_cache(): filename = os.path.join(os.path.dirname(ocf_datapipes.__file__), "../tests/config/test.yaml") - gsp_datapipe = metnet_site_datapipe(filename, use_nwp=False, pv_in_image=True, cache_to_disk=True) + gsp_datapipe = metnet_site_datapipe( + filename, use_nwp=False, pv_in_image=True, cache_to_disk=True + ) batch = next(iter(gsp_datapipe)) assert np.isfinite(batch[0]).all()