Skip to content
This repository was archived by the owner on Jun 2, 2025. It is now read-only.

Commit 5e6ed95

Browse files
committed
Base configuration off config file
1 parent 4459994 commit 5e6ed95

File tree

6 files changed

+84
-34
lines changed

6 files changed

+84
-34
lines changed

ocf_datapipes/training/common.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -542,21 +542,18 @@ def _get_datapipes_dict(
542542
config_filename: str,
543543
block_sat: bool,
544544
block_nwp: bool,
545-
block_sensor: bool = False,
546-
block_gsp: bool = False,
547-
block_pv: bool = False,
548545
production: bool = False,
549546
):
550547
# Load datasets
551548
datapipes_dict = open_and_return_datapipes(
552549
configuration_filename=config_filename,
553-
use_gsp=(not production and not block_gsp),
554-
use_pv=(not production and not block_pv),
550+
use_gsp=(not production),
551+
use_pv=(not production),
555552
use_sat=not block_sat, # Only loaded if we aren't replacing them with zeros
556-
use_hrv=False,
553+
use_hrv=True,
557554
use_nwp=not block_nwp, # Only loaded if we aren't replacing them with zeros
558-
use_topo=False,
559-
use_sensor=not block_sensor,
555+
use_topo=True,
556+
use_sensor=True,
560557
production=production,
561558
)
562559

@@ -586,9 +583,6 @@ def construct_loctime_pipelines(
586583
end_time: Optional[datetime] = None,
587584
block_sat: bool = False,
588585
block_nwp: bool = False,
589-
block_sensor: bool = False,
590-
block_gsp: bool = False,
591-
block_pv: bool = False,
592586
) -> Tuple[IterDataPipe, IterDataPipe]:
593587
"""Construct location and time pipelines for the input data config file.
594588
@@ -604,9 +598,6 @@ def construct_loctime_pipelines(
604598
config_filename,
605599
block_sat=block_sat,
606600
block_nwp=block_nwp,
607-
block_gsp=block_gsp,
608-
block_pv=block_pv,
609-
block_sensor=block_sensor,
610601
)
611602

612603
# Pull out config file

ocf_datapipes/training/windnet.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -213,9 +213,6 @@ def construct_sliced_data_pipeline(
213213
config_filename,
214214
block_sat,
215215
block_nwp,
216-
block_sensor=False,
217-
block_gsp=True,
218-
block_pv=True,
219216
production=production,
220217
)
221218

@@ -252,13 +249,13 @@ def construct_sliced_data_pipeline(
252249

253250
if "sensor" in datapipes_dict:
254251
# Recombine Sensor arrays - see function doc for further explanation
255-
pv_datapipe = (
252+
sensor_datapipe = (
256253
datapipes_dict["sensor"]
257254
.zip_ocf(datapipes_dict["sensor_future"])
258255
.map(concat_xr_time_utc)
259256
)
260-
pv_datapipe = pv_datapipe.normalize(normalize_fn=_normalize_wind_speed)
261-
pv_datapipe = pv_datapipe.map(fill_nans_in_pv)
257+
sensor_datapipe = sensor_datapipe.normalize(normalize_fn=_normalize_wind_speed)
258+
sensor_datapipe = sensor_datapipe.map(fill_nans_in_pv)
262259

263260
finished_dataset_dict = {"config": configuration}
264261
# GSP always assumed to be in data
@@ -289,10 +286,8 @@ def construct_sliced_data_pipeline(
289286
finished_dataset_dict["nwp"] = nwp_datapipe
290287
if "sat" in datapipes_dict:
291288
finished_dataset_dict["sat"] = sat_datapipe
292-
if "pv" in datapipes_dict:
293-
finished_dataset_dict["pv"] = pv_datapipe
294289
if "sensor" in datapipes_dict:
295-
finished_dataset_dict["sensor"] = pv_datapipe
290+
finished_dataset_dict["sensor"] = sensor_datapipe
296291

297292
return finished_dataset_dict
298293

@@ -303,8 +298,6 @@ def windnet_datapipe(
303298
end_time: Optional[datetime] = None,
304299
block_sat: bool = False,
305300
block_nwp: bool = False,
306-
block_sensor: bool = False,
307-
block_pv: bool = True,
308301
) -> IterDataPipe:
309302
"""
310303
Construct windnet pipeline for the input data config file.
@@ -325,9 +318,6 @@ def windnet_datapipe(
325318
end_time,
326319
block_sat=block_sat,
327320
block_nwp=block_nwp,
328-
block_sensor=block_sensor,
329-
block_gsp=True,
330-
block_pv=block_pv,
331321
)
332322

333323
# Shard after we have the loc-times. These are already shuffled so no need to shuffle again

tests/config/wind_test.yaml

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
general:
2+
description: !ENV example configuration with env ${PATH} # noqa
3+
name: example
4+
git: null
5+
input_data:
6+
nwp:
7+
nwp_channels:
8+
- t
9+
nwp_image_size_pixels_height: 2
10+
nwp_image_size_pixels_width: 2
11+
nwp_zarr_path: tests/data/nwp_data/test.zarr
12+
nwp_provider: "ukv"
13+
history_minutes: 60
14+
forecast_minutes: 120
15+
time_resolution_minutes: 60
16+
index_by_id: True
17+
pv:
18+
pv_files_groups:
19+
- label: solar_sheffield_passiv
20+
pv_filename: tests/data/pv/passiv/test.nc
21+
pv_metadata_filename: tests/data/pv/passiv/UK_PV_metadata.csv
22+
- label: pvoutput.org
23+
pv_filename: tests/data/pv/pvoutput/test.nc
24+
pv_metadata_filename: tests/data/pv/pvoutput/UK_PV_metadata.csv
25+
get_center: false
26+
pv_image_size_meters_height: 10000000
27+
pv_image_size_meters_width: 10000000
28+
n_pv_systems_per_example: 32
29+
start_datetime: "2010-01-01 00:00:00"
30+
end_datetime: "2030-01-01 00:00:00"
31+
pv_ml_ids: []
32+
satellite:
33+
satellite_channels:
34+
- IR_016
35+
satellite_image_size_pixels_height: 24
36+
satellite_image_size_pixels_width: 24
37+
satellite_zarr_path: tests/data/sat_data.zarr
38+
hrvsatellite:
39+
hrvsatellite_channels:
40+
- HRV
41+
hrvsatellite_image_size_pixels_height: 64
42+
hrvsatellite_image_size_pixels_width: 64
43+
hrvsatellite_zarr_path: tests/data/hrv_sat_data.zarr
44+
history_minutes: 30
45+
forecast_minutes: 60
46+
output_data:
47+
filepath: not used by unittests!
48+
process:
49+
batch_size: 4
50+
local_temp_path: ~/temp/
51+
seed: 1234
52+
upload_every_n_batches: 16
53+
n_train_batches: 2
54+
n_validation_batches: 0
55+
n_test_batches: 0
56+
train_test_validation_split: [3, 0, 1]

tests/conftest.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -597,6 +597,15 @@ def configuration():
597597
return load_yaml_configuration(filename=filename)
598598

599599

600+
@pytest.fixture()
601+
def configuration_no_gsp():
602+
filename = os.path.join(
603+
os.path.dirname(ocf_datapipes.__file__), "../tests/config/wind_test.yaml"
604+
)
605+
606+
return load_yaml_configuration(filename=filename)
607+
608+
600609
@pytest.fixture()
601610
def configuration_with_pv_netcdf(pv_netcdf_file):
602611
filename = os.path.join(os.path.dirname(ocf_datapipes.__file__), "../tests/config/test.yaml")
@@ -651,3 +660,11 @@ def configuration_with_gsp_and_nwp(gsp_zarr_file, nwp_data_with_id_filename):
651660
def configuration_filename():
652661
filename = os.path.join(os.path.dirname(ocf_datapipes.__file__), "../tests/config/test.yaml")
653662
yield filename
663+
664+
665+
@pytest.fixture()
666+
def wind_configuration_filename():
667+
filename = os.path.join(
668+
os.path.dirname(ocf_datapipes.__file__), "../tests/config/wind_test.yaml"
669+
)
670+
yield filename

tests/training/test_windnet.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@ def test_windnet_datapipe(configuration_filename):
1515
configuration_filename,
1616
start_time=start_time,
1717
end_time=end_time,
18-
block_sensor=True,
19-
block_pv=False,
2018
)
2119
datasets = next(iter(dp))
2220
# Need to serialize attributes to strings

tests/utils/test_utils.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,13 @@ def test_searchsorted():
1313
assert searchsorted(ys_r, 2.1, assume_ascending=False) == 3
1414

1515

16-
def test_combine_uncombine_from_single_dataset(configuration_filename):
16+
def test_combine_uncombine_from_single_dataset(wind_configuration_filename):
1717
start_time = datetime(1900, 1, 1)
1818
end_time = datetime(2050, 1, 1)
1919
dp = windnet_datapipe(
20-
configuration_filename,
20+
wind_configuration_filename,
2121
start_time=start_time,
2222
end_time=end_time,
23-
block_sensor=True,
24-
block_pv=False,
2523
)
2624
dataset: xr.Dataset = next(iter(dp))
2725
assert isinstance(dataset, xr.Dataset)

0 commit comments

Comments
 (0)