Skip to content
This repository was archived by the owner on Jun 2, 2025. It is now read-only.

Add Caching to Disk in Datapipe #182

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 25 additions & 4 deletions ocf_datapipes/training/metnet_pv_site.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pathlib import Path
from typing import Union

import numpy as np
import xarray
from torchdata.datapipes.iter import IterDataPipe

Expand Down Expand Up @@ -44,6 +45,12 @@ def _load_xarray_values(x):
return x.load()


def _filepath_fn(xr_data):
# Get filepath from metadata, including time, and location and return it
file_name = f"{np.random.randint(10000000)}.npy"
return file_name


def metnet_site_datapipe(
configuration_filename: Union[Path, str],
use_sun: bool = True,
Expand All @@ -59,6 +66,7 @@ def metnet_site_datapipe(
center_size_meters: int = 64_000,
context_size_meters: int = 512_000,
batch_size: int = 1,
cache_to_disk: bool = False,
) -> IterDataPipe:
"""
Make GSP national data pipe
Expand All @@ -80,6 +88,7 @@ def metnet_site_datapipe(
center_size_meters: Center size for MetNet cutouts, in meters
context_size_meters: Context area size in meters
batch_size: Batch size for the datapipe
cache_to_disk: Whether to cache to disk or not

Returns: datapipe
"""
Expand Down Expand Up @@ -108,7 +117,8 @@ def metnet_site_datapipe(
pv_datapipe = used_datapipes["pv_future"].normalize(normalize_fn=normalize_pv)
# Split into GSP for target, only national, and one for history
pv_datapipe, pv_loc_datapipe = pv_datapipe.fork(2)
pv_loc_datapipe, pv_id_datapipe = LocationPicker(pv_loc_datapipe).fork(2)
pv_loc_datapipe = LocationPicker(pv_loc_datapipe)
pv_loc_datapipe, pv_id_datapipe = pv_loc_datapipe.fork(2)
pv_history = pv_history.select_id(pv_id_datapipe, data_source_name="pv")

if "nwp" in used_datapipes.keys():
Expand Down Expand Up @@ -202,13 +212,24 @@ def metnet_site_datapipe(
output_height_pixels=output_size,
add_sun_features=use_sun,
)

pv_datapipe = ConvertPVToNumpy(pv_datapipe)

if not pv_in_image:
pv_history = pv_history.map(_remove_nans)
pv_history = ConvertPVToNumpy(pv_history, return_pv_id=True)
return metnet_datapipe.batch(batch_size).zip_ocf(
pv_history.batch(batch_size), pv_datapipe.batch(batch_size)
combined_datapipe = (
metnet_datapipe.batch(batch_size)
.zip_ocf(pv_history.batch(batch_size), pv_datapipe.batch(batch_size))
.on_disk_cache(filepath_fn=_filepath_fn)
)
else:
return metnet_datapipe.batch(batch_size).zip_ocf(pv_datapipe.batch(batch_size))
combined_datapipe = (
metnet_datapipe.batch(batch_size)
.zip_ocf(pv_datapipe.batch(batch_size))
.on_disk_cache(filepath_fn=_filepath_fn)
)

if cache_to_disk:
combined_datapipe = combined_datapipe.end_caching()
return combined_datapipe
11 changes: 11 additions & 0 deletions tests/training/test_metnet_pv_site.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,14 @@ def test_metnet_datapipe():
batch = next(iter(gsp_datapipe))
assert np.isfinite(batch[0]).all()
assert np.isfinite(batch[1]).all()


def test_metnet_datapipe_cache():
    """Smoke-test the MetNet site datapipe with disk caching enabled.

    Builds the datapipe from the test config with ``cache_to_disk=True``,
    pulls one batch, and checks every value in the first two batch
    elements is finite.
    """
    config_path = os.path.join(
        os.path.dirname(ocf_datapipes.__file__), "../tests/config/test.yaml"
    )
    datapipe = metnet_site_datapipe(
        config_path,
        use_nwp=False,
        pv_in_image=True,
        cache_to_disk=True,
    )

    first_batch = next(iter(datapipe))
    for part in (first_batch[0], first_batch[1]):
        assert np.isfinite(part).all()