Skip to content

Commit

Permalink
update to nested NWP, clean, refactor, and fix tests #minor
Browse files Browse the repository at this point in the history
  • Loading branch information
dfulu committed Dec 11, 2023
1 parent 5c5687a commit 56b243d
Show file tree
Hide file tree
Showing 71 changed files with 1,584 additions and 1,443 deletions.
2 changes: 1 addition & 1 deletion ocf_datapipes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@
import ocf_datapipes.select
import ocf_datapipes.transform
import ocf_datapipes.utils
import ocf_datapipes.validation
import ocf_datapipes.validation
7 changes: 7 additions & 0 deletions ocf_datapipes/batch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
"""Datapipes for batching together data"""
from .merge_numpy_examples_to_batch import (
stack_np_examples_into_batch,
unstack_np_batch_into_examples,
MergeNumpyBatchIterDataPipe as MergeNumpyBatch,
MergeNumpyExamplesToBatchIterDataPipe as MergeNumpyExamplesToBatch,
)
from .merge_numpy_modalities import (
MergeNumpyModalitiesIterDataPipe as MergeNumpyModalities,
MergeNWPNumpyModalitiesIterDataPipe as MergeNWPNumpyModalities
)
23 changes: 14 additions & 9 deletions ocf_datapipes/batch/fake/fake_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,17 @@
from ocf_datapipes.utils.utils import datetime64_to_float


def make_fake_batch(configuration: Configuration, to_torch: Optional[bool] = False) -> dict:
def make_fake_batch(
configuration: Configuration,
batch_size: int = 8,
to_torch: Optional[bool] = False,
) -> dict:
"""
Make a random fake batch, this is useful for models that use this object
Args:
configuration: a configuration file
batch_size: the batch size
to_torch: optional if we return the batch with torch.Tensor
Returns: dictionary containing the batch
Expand All @@ -35,24 +40,24 @@ def make_fake_batch(configuration: Configuration, to_torch: Optional[bool] = Fal
t0_datetime_utc = t0_datetime_utc.replace(microsecond=0)

# make fake PV data
batch_pv = make_fake_pv_data(configuration=configuration, t0_datetime_utc=t0_datetime_utc)
batch_pv = make_fake_pv_data(configuration, t0_datetime_utc, batch_size)

# make NWP data
batch_nwp = make_fake_nwp_data(configuration=configuration, t0_datetime_utc=t0_datetime_utc)
batch_nwp = make_fake_nwp_data(configuration, t0_datetime_utc, batch_size)

# make GSP data
batch_gsp = make_fake_gsp_data(configuration=configuration, t0_datetime_utc=t0_datetime_utc)
batch_gsp = make_fake_gsp_data(configuration, t0_datetime_utc, batch_size)

# make hrv and normal satellite data
batch_satellite = make_fake_satellite_data(
configuration=configuration, t0_datetime_utc=t0_datetime_utc, is_hrv=False
configuration, t0_datetime_utc, is_hrv=False, batch_size=batch_size,
)
batch_hrv_satellite = make_fake_satellite_data(
configuration=configuration, t0_datetime_utc=t0_datetime_utc, is_hrv=True
configuration, t0_datetime_utc, is_hrv=True, batch_size=batch_size,
)

# make sun features
batch_sun = make_fake_sun_data(configuration=configuration)
batch_sun = make_fake_sun_data(configuration, batch_size)

batch = {
**batch_pv,
Expand All @@ -76,7 +81,7 @@ def make_fake_batch(configuration: Configuration, to_torch: Optional[bool] = Fal
return batch


def fake_data_pipeline(configuration: Union[str, Configuration]):
def fake_data_pipeline(configuration: Union[str, Configuration], batch_size: int = 8):
"""
Make a fake data pipeline
Expand All @@ -88,7 +93,7 @@ def fake_data_pipeline(configuration: Union[str, Configuration]):
if isinstance(configuration, str):
configuration = load_yaml_configuration(configuration)

batch = make_fake_batch(configuration=configuration, to_torch=True)
batch = make_fake_batch(configuration=configuration, to_torch=True, batch_size=batch_size)

def fake_iter():
while True:
Expand Down
7 changes: 5 additions & 2 deletions ocf_datapipes/batch/fake/gsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@
from ocf_datapipes.utils.consts import BatchKey


def make_fake_gsp_data(configuration: Configuration, t0_datetime_utc: datetime) -> dict:
def make_fake_gsp_data(
configuration: Configuration,
t0_datetime_utc: datetime,
batch_size: int = 8
) -> dict:
"""
Make Fake GSP data ready for ML model
Expand All @@ -24,7 +28,6 @@ def make_fake_gsp_data(configuration: Configuration, t0_datetime_utc: datetime)
if gsp_config is None:
return {}

batch_size = configuration.process.batch_size
n_gsps = gsp_config.n_gsp_per_example
n_fourier_features = 8

Expand Down
98 changes: 57 additions & 41 deletions ocf_datapipes/batch/fake/nwp.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,14 @@

from ocf_datapipes.batch.fake.utils import get_n_time_steps_from_config, make_time_utc
from ocf_datapipes.config.model import Configuration
from ocf_datapipes.utils.consts import BatchKey
from ocf_datapipes.utils.consts import NWPBatchKey, BatchKey


def make_fake_nwp_data(configuration: Configuration, t0_datetime_utc: datetime):
def make_fake_nwp_data(
configuration: Configuration,
t0_datetime_utc: datetime,
batch_size: int = 8
) -> dict:
"""
Make Fake NWP data ready for ML model
Expand All @@ -20,49 +24,61 @@ def make_fake_nwp_data(configuration: Configuration, t0_datetime_utc: datetime):
"""

nwp_config = configuration.input_data.nwp
if nwp_config is None:

if configuration.input_data.nwp is None:
return {}

batch = {}

for nwp_source, nwp_config in configuration.input_data.nwp.items():

source_batch = {}

batch_size = configuration.process.batch_size
n_channels = len(nwp_config.nwp_channels)
n_y_osgb = nwp_config.nwp_image_size_pixels_height
n_x_osgb = nwp_config.nwp_image_size_pixels_width
n_fourier_features = 8
n_channels = len(nwp_config.nwp_channels)
n_y_osgb = nwp_config.nwp_image_size_pixels_height
n_x_osgb = nwp_config.nwp_image_size_pixels_width
n_fourier_features = 8

# make time matrix
time_utc = make_time_utc(
batch_size=batch_size,
history_minutes=nwp_config.history_minutes,
forecast_minutes=nwp_config.forecast_minutes,
t0_datetime_utc=t0_datetime_utc,
time_resolution_minutes=nwp_config.time_resolution_minutes,
)
n_times = time_utc.shape[1]
# make time matrix
time_utc = make_time_utc(
batch_size=batch_size,
history_minutes=nwp_config.history_minutes,
forecast_minutes=nwp_config.forecast_minutes,
t0_datetime_utc=t0_datetime_utc,
time_resolution_minutes=nwp_config.time_resolution_minutes,
)
n_times = time_utc.shape[1]

# main nwp components
batch = {}
batch[BatchKey.nwp_init_time_utc] = time_utc # Seconds since UNIX epoch (1970-01-01).
batch[BatchKey.nwp_target_time_utc] = time_utc # Seconds since UNIX epoch (1970-01-01).
batch[BatchKey.nwp] = np.random.random((batch_size, n_times, n_channels, n_y_osgb, n_x_osgb))
batch[BatchKey.nwp_t0_idx] = get_n_time_steps_from_config(
input_data_configuration=nwp_config, include_forecast=False
)
# main nwp components

source_batch[NWPBatchKey.nwp_init_time_utc] = time_utc # Seconds since UNIX epoch (1970-01-01).
source_batch[NWPBatchKey.nwp_target_time_utc] = time_utc # Seconds since UNIX epoch (1970-01-01).
source_batch[NWPBatchKey.nwp] = np.random.random(
(batch_size, n_times, n_channels, n_y_osgb, n_x_osgb)
)
source_batch[NWPBatchKey.nwp_t0_idx] = get_n_time_steps_from_config(
input_data_configuration=nwp_config, include_forecast=False
)

batch[BatchKey.nwp_step] = np.random.randint(0, 100, (batch_size, n_times))
batch[BatchKey.nwp_y_osgb] = np.random.randint(0, 100, (batch_size, n_y_osgb))
batch[BatchKey.nwp_x_osgb] = np.random.randint(0, 100, (batch_size, n_x_osgb))
batch[BatchKey.nwp_channel_names] = np.random.randint(0, 100, (n_channels,))
source_batch[NWPBatchKey.nwp_step] = np.random.randint(0, 100, (batch_size, n_times))
source_batch[NWPBatchKey.nwp_y_osgb] = np.random.randint(0, 100, (batch_size, n_y_osgb))
source_batch[NWPBatchKey.nwp_x_osgb] = np.random.randint(0, 100, (batch_size, n_x_osgb))
source_batch[NWPBatchKey.nwp_channel_names] = np.random.randint(0, 100, (n_channels,))

# fourier components
batch[BatchKey.nwp_x_osgb_fourier] = np.random.random(
(batch_size, n_x_osgb, n_fourier_features)
)
batch[BatchKey.nwp_y_osgb_fourier] = np.random.random(
(batch_size, n_y_osgb, n_fourier_features)
)
batch[BatchKey.nwp_target_time_utc] = np.random.random(
(batch_size, n_times, n_fourier_features)
)
batch[BatchKey.nwp_init_time_utc] = np.random.random((batch_size, n_times, n_fourier_features))
# fourier components
source_batch[NWPBatchKey.nwp_x_osgb_fourier] = np.random.random(
(batch_size, n_x_osgb, n_fourier_features)
)
source_batch[NWPBatchKey.nwp_y_osgb_fourier] = np.random.random(
(batch_size, n_y_osgb, n_fourier_features)
)
source_batch[NWPBatchKey.nwp_target_time_utc] = np.random.random(
(batch_size, n_times, n_fourier_features)
)
source_batch[NWPBatchKey.nwp_init_time_utc] = np.random.random(
(batch_size, n_times, n_fourier_features)
)

batch[nwp_source] = source_batch

return batch
return {BatchKey.nwp:batch}
6 changes: 4 additions & 2 deletions ocf_datapipes/batch/fake/pv.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
from ocf_datapipes.utils.consts import BatchKey


def make_fake_pv_data(configuration: Configuration, t0_datetime_utc: datetime):
def make_fake_pv_data( configuration: Configuration,
t0_datetime_utc: datetime,
batch_size: int = 8
) -> dict:
"""
Make Fake PV data ready for ML model
Expand All @@ -23,7 +26,6 @@ def make_fake_pv_data(configuration: Configuration, t0_datetime_utc: datetime):
if pv_config is None:
return {}

batch_size = configuration.process.batch_size
n_pv_systems = pv_config.n_pv_systems_per_example
n_fourier_features = 8

Expand Down
9 changes: 6 additions & 3 deletions ocf_datapipes/batch/fake/satellite.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,12 @@


def make_fake_satellite_data(
configuration: Configuration, t0_datetime_utc: datetime, is_hrv: bool = False
):
configuration: Configuration,
t0_datetime_utc: datetime,
is_hrv: bool = False,
batch_size: int = 8,
) -> dict:

"""
Make Fake Satellite data ready for ML model. This makes data across all different data inputs
Expand All @@ -34,7 +38,6 @@ def make_fake_satellite_data(
if satellite_config is None:
return {}

batch_size = configuration.process.batch_size
n_channels = len(getattr(satellite_config, f"{variable}_channels"))
height = getattr(satellite_config, f"{variable}_image_size_pixels_height")
width = getattr(satellite_config, f"{variable}_image_size_pixels_width")
Expand Down
24 changes: 14 additions & 10 deletions ocf_datapipes/batch/fake/sun.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@

from ocf_datapipes.batch.fake.utils import get_n_time_steps_from_config
from ocf_datapipes.config.model import Configuration
from ocf_datapipes.utils.consts import BatchKey
from ocf_datapipes.utils.consts import BatchKey, NWPBatchKey


def make_fake_sun_data(configuration: Configuration):
def make_fake_sun_data(configuration: Configuration, batch_size: int = 8):
"""
Make Fake Sun data ready for ML model. This makes data across all different data inputs
Expand All @@ -18,7 +18,6 @@ def make_fake_sun_data(configuration: Configuration):
"""

batch = {}
batch_size = configuration.process.batch_size

# HRV Satellite
if configuration.input_data.hrvsatellite is not None:
Expand Down Expand Up @@ -56,12 +55,17 @@ def make_fake_sun_data(configuration: Configuration):

# NWP
if configuration.input_data.nwp is not None:
n_nwp_timesteps = get_n_time_steps_from_config(configuration.input_data.nwp)
batch[BatchKey.nwp_target_time_solar_azimuth] = np.random.random(
(batch_size, n_nwp_timesteps)
)
batch[BatchKey.nwp_target_time_solar_elevation] = np.random.random(
(batch_size, n_nwp_timesteps)
)
batch[BatchKey.nwp] = {}

for nwp_source, nwp_config in configuration.input_data.nwp.items():
batch[BatchKey.nwp][nwp_source] = {}

n_nwp_timesteps = get_n_time_steps_from_config(configuration.input_data.nwp[nwp_source])
batch[BatchKey.nwp][nwp_source][NWPBatchKey.nwp_target_time_solar_azimuth] = (
np.random.random((batch_size, n_nwp_timesteps))
)
batch[BatchKey.nwp][nwp_source][NWPBatchKey.nwp_target_time_solar_elevation] = (
np.random.random((batch_size, n_nwp_timesteps))
)

return batch
Loading

0 comments on commit 56b243d

Please sign in to comment.