Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Caching to Disk in Datapipe #182

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 25 additions & 4 deletions ocf_datapipes/training/metnet_pv_site.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pathlib import Path
from typing import Union

import numpy as np
import xarray
from torchdata.datapipes.iter import IterDataPipe

Expand Down Expand Up @@ -44,6 +45,12 @@ def _load_xarray_values(x):
return x.load()


def _filepath_fn(xr_data):
# Get filepath from metadata, including time, and location and return it
file_name = f"{np.random.randint(10000000)}.npy"
return file_name


def metnet_site_datapipe(
configuration_filename: Union[Path, str],
use_sun: bool = True,
Expand All @@ -59,6 +66,7 @@ def metnet_site_datapipe(
center_size_meters: int = 64_000,
context_size_meters: int = 512_000,
batch_size: int = 1,
cache_to_disk: bool = False,
) -> IterDataPipe:
"""
Make GSP national data pipe
Expand All @@ -80,6 +88,7 @@ def metnet_site_datapipe(
center_size_meters: Center size for MetNet cutouts, in meters
context_size_meters: Context area size in meters
batch_size: Batch size for the datapipe
cache_to_disk: Whether to cache to disk or not

Returns: datapipe
"""
Expand Down Expand Up @@ -108,7 +117,8 @@ def metnet_site_datapipe(
pv_datapipe = used_datapipes["pv_future"].normalize(normalize_fn=normalize_pv)
# Split into GSP for target, only national, and one for history
pv_datapipe, pv_loc_datapipe = pv_datapipe.fork(2)
pv_loc_datapipe, pv_id_datapipe = LocationPicker(pv_loc_datapipe).fork(2)
pv_loc_datapipe = LocationPicker(pv_loc_datapipe)
pv_loc_datapipe, pv_id_datapipe = pv_loc_datapipe.fork(2)
pv_history = pv_history.select_id(pv_id_datapipe, data_source_name="pv")

if "nwp" in used_datapipes.keys():
Expand Down Expand Up @@ -202,13 +212,24 @@ def metnet_site_datapipe(
output_height_pixels=output_size,
add_sun_features=use_sun,
)

pv_datapipe = ConvertPVToNumpy(pv_datapipe)

if not pv_in_image:
pv_history = pv_history.map(_remove_nans)
pv_history = ConvertPVToNumpy(pv_history, return_pv_id=True)
return metnet_datapipe.batch(batch_size).zip_ocf(
pv_history.batch(batch_size), pv_datapipe.batch(batch_size)
combined_datapipe = (
metnet_datapipe.batch(batch_size)
.zip_ocf(pv_history.batch(batch_size), pv_datapipe.batch(batch_size))
.on_disk_cache(filepath_fn=_filepath_fn)
)
else:
return metnet_datapipe.batch(batch_size).zip_ocf(pv_datapipe.batch(batch_size))
combined_datapipe = (
metnet_datapipe.batch(batch_size)
.zip_ocf(pv_datapipe.batch(batch_size))
.on_disk_cache(filepath_fn=_filepath_fn)
)

if cache_to_disk:
combined_datapipe = combined_datapipe.end_caching()
return combined_datapipe
11 changes: 11 additions & 0 deletions tests/training/test_metnet_pv_site.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,14 @@ def test_metnet_datapipe():
batch = next(iter(gsp_datapipe))
assert np.isfinite(batch[0]).all()
assert np.isfinite(batch[1]).all()


def test_metnet_datapipe_cache():
    """Smoke-test the MetNet PV-site datapipe with disk caching enabled."""
    config_path = os.path.join(
        os.path.dirname(ocf_datapipes.__file__), "../tests/config/test.yaml"
    )
    datapipe = metnet_site_datapipe(
        config_path, use_nwp=False, pv_in_image=True, cache_to_disk=True
    )

    sample = next(iter(datapipe))
    # Both the image stack and the PV target must be free of NaN/inf.
    for part in (sample[0], sample[1]):
        assert np.isfinite(part).all()