Skip to content

Commit

Permalink
Base configuration off config file
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobbieker committed Nov 24, 2023
1 parent 4459994 commit 5e6ed95
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 34 deletions.
19 changes: 5 additions & 14 deletions ocf_datapipes/training/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,21 +542,18 @@ def _get_datapipes_dict(
config_filename: str,
block_sat: bool,
block_nwp: bool,
block_sensor: bool = False,
block_gsp: bool = False,
block_pv: bool = False,
production: bool = False,
):
# Load datasets
datapipes_dict = open_and_return_datapipes(
configuration_filename=config_filename,
use_gsp=(not production and not block_gsp),
use_pv=(not production and not block_pv),
use_gsp=(not production),
use_pv=(not production),
use_sat=not block_sat, # Only loaded if we aren't replacing them with zeros
use_hrv=False,
use_hrv=True,
use_nwp=not block_nwp, # Only loaded if we aren't replacing them with zeros
use_topo=False,
use_sensor=not block_sensor,
use_topo=True,
use_sensor=True,
production=production,
)

Expand Down Expand Up @@ -586,9 +583,6 @@ def construct_loctime_pipelines(
end_time: Optional[datetime] = None,
block_sat: bool = False,
block_nwp: bool = False,
block_sensor: bool = False,
block_gsp: bool = False,
block_pv: bool = False,
) -> Tuple[IterDataPipe, IterDataPipe]:
"""Construct location and time pipelines for the input data config file.
Expand All @@ -604,9 +598,6 @@ def construct_loctime_pipelines(
config_filename,
block_sat=block_sat,
block_nwp=block_nwp,
block_gsp=block_gsp,
block_pv=block_pv,
block_sensor=block_sensor,
)

# Pull out config file
Expand Down
18 changes: 4 additions & 14 deletions ocf_datapipes/training/windnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,6 @@ def construct_sliced_data_pipeline(
config_filename,
block_sat,
block_nwp,
block_sensor=False,
block_gsp=True,
block_pv=True,
production=production,
)

Expand Down Expand Up @@ -252,13 +249,13 @@ def construct_sliced_data_pipeline(

if "sensor" in datapipes_dict:
# Recombine Sensor arrays - see function doc for further explanation
pv_datapipe = (
sensor_datapipe = (
datapipes_dict["sensor"]
.zip_ocf(datapipes_dict["sensor_future"])
.map(concat_xr_time_utc)
)
pv_datapipe = pv_datapipe.normalize(normalize_fn=_normalize_wind_speed)
pv_datapipe = pv_datapipe.map(fill_nans_in_pv)
sensor_datapipe = sensor_datapipe.normalize(normalize_fn=_normalize_wind_speed)
sensor_datapipe = sensor_datapipe.map(fill_nans_in_pv)

finished_dataset_dict = {"config": configuration}
# GSP always assumed to be in data
Expand Down Expand Up @@ -289,10 +286,8 @@ def construct_sliced_data_pipeline(
finished_dataset_dict["nwp"] = nwp_datapipe
if "sat" in datapipes_dict:
finished_dataset_dict["sat"] = sat_datapipe
if "pv" in datapipes_dict:
finished_dataset_dict["pv"] = pv_datapipe
if "sensor" in datapipes_dict:
finished_dataset_dict["sensor"] = pv_datapipe
finished_dataset_dict["sensor"] = sensor_datapipe

return finished_dataset_dict

Expand All @@ -303,8 +298,6 @@ def windnet_datapipe(
end_time: Optional[datetime] = None,
block_sat: bool = False,
block_nwp: bool = False,
block_sensor: bool = False,
block_pv: bool = True,
) -> IterDataPipe:
"""
Construct windnet pipeline for the input data config file.
Expand All @@ -325,9 +318,6 @@ def windnet_datapipe(
end_time,
block_sat=block_sat,
block_nwp=block_nwp,
block_sensor=block_sensor,
block_gsp=True,
block_pv=block_pv,
)

# Shard after we have the loc-times. These are already shuffled so no need to shuffle again
Expand Down
56 changes: 56 additions & 0 deletions tests/config/wind_test.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
general:
description: !ENV example configuration with env ${PATH} # noqa
name: example
git: null
input_data:
nwp:
nwp_channels:
- t
nwp_image_size_pixels_height: 2
nwp_image_size_pixels_width: 2
nwp_zarr_path: tests/data/nwp_data/test.zarr
nwp_provider: "ukv"
history_minutes: 60
forecast_minutes: 120
time_resolution_minutes: 60
index_by_id: True
pv:
pv_files_groups:
- label: solar_sheffield_passiv
pv_filename: tests/data/pv/passiv/test.nc
pv_metadata_filename: tests/data/pv/passiv/UK_PV_metadata.csv
- label: pvoutput.org
pv_filename: tests/data/pv/pvoutput/test.nc
pv_metadata_filename: tests/data/pv/pvoutput/UK_PV_metadata.csv
get_center: false
pv_image_size_meters_height: 10000000
pv_image_size_meters_width: 10000000
n_pv_systems_per_example: 32
start_datetime: "2010-01-01 00:00:00"
end_datetime: "2030-01-01 00:00:00"
pv_ml_ids: []
satellite:
satellite_channels:
- IR_016
satellite_image_size_pixels_height: 24
satellite_image_size_pixels_width: 24
satellite_zarr_path: tests/data/sat_data.zarr
hrvsatellite:
hrvsatellite_channels:
- HRV
hrvsatellite_image_size_pixels_height: 64
hrvsatellite_image_size_pixels_width: 64
hrvsatellite_zarr_path: tests/data/hrv_sat_data.zarr
history_minutes: 30
forecast_minutes: 60
output_data:
filepath: not used by unittests!
process:
batch_size: 4
local_temp_path: ~/temp/
seed: 1234
upload_every_n_batches: 16
n_train_batches: 2
n_validation_batches: 0
n_test_batches: 0
train_test_validation_split: [3, 0, 1]
17 changes: 17 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,15 @@ def configuration():
return load_yaml_configuration(filename=filename)


@pytest.fixture()
def configuration_no_gsp():
    """Load the configuration stored in ``tests/config/wind_test.yaml``.

    The path is built relative to the directory containing the
    ``ocf_datapipes`` package and the file is parsed with
    ``load_yaml_configuration``.
    """
    package_dir = os.path.dirname(ocf_datapipes.__file__)
    wind_config_path = os.path.join(package_dir, "../tests/config/wind_test.yaml")
    return load_yaml_configuration(filename=wind_config_path)


@pytest.fixture()
def configuration_with_pv_netcdf(pv_netcdf_file):
filename = os.path.join(os.path.dirname(ocf_datapipes.__file__), "../tests/config/test.yaml")
Expand Down Expand Up @@ -651,3 +660,11 @@ def configuration_with_gsp_and_nwp(gsp_zarr_file, nwp_data_with_id_filename):
def configuration_filename():
    """Yield the path to ``tests/config/test.yaml``.

    The path is built relative to the directory containing the
    ``ocf_datapipes`` package; the file itself is not opened here.
    """
    package_dir = os.path.dirname(ocf_datapipes.__file__)
    yield os.path.join(package_dir, "../tests/config/test.yaml")


@pytest.fixture()
def wind_configuration_filename():
    """Yield the path to ``tests/config/wind_test.yaml``.

    The path is built relative to the directory containing the
    ``ocf_datapipes`` package; the file itself is not opened here.
    """
    package_dir = os.path.dirname(ocf_datapipes.__file__)
    yield os.path.join(package_dir, "../tests/config/wind_test.yaml")
2 changes: 0 additions & 2 deletions tests/training/test_windnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@ def test_windnet_datapipe(configuration_filename):
configuration_filename,
start_time=start_time,
end_time=end_time,
block_sensor=True,
block_pv=False,
)
datasets = next(iter(dp))
# Need to serialize attributes to strings
Expand Down
6 changes: 2 additions & 4 deletions tests/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,13 @@ def test_searchsorted():
assert searchsorted(ys_r, 2.1, assume_ascending=False) == 3


def test_combine_uncombine_from_single_dataset(configuration_filename):
def test_combine_uncombine_from_single_dataset(wind_configuration_filename):
start_time = datetime(1900, 1, 1)
end_time = datetime(2050, 1, 1)
dp = windnet_datapipe(
configuration_filename,
wind_configuration_filename,
start_time=start_time,
end_time=end_time,
block_sensor=True,
block_pv=False,
)
dataset: xr.Dataset = next(iter(dp))
assert isinstance(dataset, xr.Dataset)
Expand Down

0 comments on commit 5e6ed95

Please sign in to comment.