Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Add test for checking physical limits and zeroes in NWP data #… #340

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
3ee287c
chore: Add test for checking physical limits and zeroes in NWP data #…
glitch401 Jul 3, 2024
1e2df80
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 3, 2024
8105b91
changes to generate test data on the go. remove unnecessary zarr file…
glitch401 Jul 4, 2024
1eafe49
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 4, 2024
d5bc6cf
Fix ValueError message for NWP data containing zeros and outside phys…
glitch401 Jul 4, 2024
d8cfa9d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 4, 2024
5e68173
Fix ValueError message coding style
glitch401 Jul 4, 2024
466b710
update physical limits in according to pvnet_uk_region/data_config.yaml
glitch401 Jul 5, 2024
692500c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 5, 2024
0667bab
Update temperature physical limits in OpenNWPIterDataPipe
glitch401 Jul 5, 2024
246d898
Fix NaN check in stack_np_examples_into_batch function
glitch401 Jul 11, 2024
55627eb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 11, 2024
7ba254d
changes made to adapt for lazy loading
glitch401 Jul 16, 2024
c6ee33d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 16, 2024
d0c4f6f
moved limits to a constant file
glitch401 Jul 24, 2024
19050c7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 24, 2024
3fe89fc
Refactor test_merge_numpy_examples_to_batch.py and test_load_nwp.py t…
glitch401 Aug 15, 2024
ace0259
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion ocf_datapipes/load/nwp/nwp.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
from typing import Union

import dask
import dask.array
import numpy as np
import xarray as xr
from constants import NWP_LIMITS
from ocf_blosc2 import Blosc2 # noqa: F401
from torch.utils.data import IterDataPipe, functional_datapipe

Expand All @@ -17,6 +18,8 @@
from ocf_datapipes.load.nwp.providers.merra2 import open_merra2
from ocf_datapipes.load.nwp.providers.ukv import open_ukv

from .constants import NWP_LIMITS

logger = logging.getLogger(__name__)


Expand All @@ -30,6 +33,7 @@ def __init__(
provider: str = "ukv",
check_for_zeros: bool = False,
check_physical_limits: bool = False,
check_for_nans: bool = False,
):
"""
Opens NWP Zarr and yields it
Expand All @@ -39,10 +43,12 @@ def __init__(
provider: NWP provider
check_for_zeros: Check for zeros in the NWP data
check_physical_limits: Check the physical limits of nwp data (e.g. -100<temperature<100)
check_for_nans: Check for NaNs in the NWP data
"""
self.zarr_path = zarr_path
self.check_for_zeros = check_for_zeros
self.check_physical_limits = check_physical_limits
self.check_for_nans = check_for_nans
self.limits = NWP_LIMITS
glitch401 marked this conversation as resolved.
Show resolved Hide resolved

logger.info(f"Using {provider.lower()}")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

very much just a suggestion, but it would be nice to have some control over which variables receive the checks. Intuitively, that should probably be possible by just passing a list of keys to be checked instead of True to check_for_zeroes/check_physical_limits

Expand Down Expand Up @@ -71,6 +77,8 @@ def __iter__(self) -> Union[xr.DataArray, xr.Dataset]: # type: ignore
self.check_if_zeros(nwp)
if self.check_physical_limits:
self.check_if_physical_limits(nwp)
if self.check_for_nans:
self.check_if_nans(nwp)
while True:
yield nwp

Expand Down Expand Up @@ -124,3 +132,20 @@ def check_if_physical_limits(self, nwp: Union[xr.DataArray, xr.Dataset]):
raise ValueError(
f"NWP data {var_name} is outside physical limits: ({lower},{upper})"
)
def check_if_nans(self, nwp: Union[xr.DataArray, xr.Dataset]):
peterdudfield marked this conversation as resolved.
Show resolved Hide resolved
"""Checks if the NWP data contains NaNs"""
if isinstance(nwp, xr.DataArray):
if dask.is_dask_collection(nwp.data):
if dask.array.isnan(nwp.data).any().compute():
raise ValueError("NWP data contains NaNs")
else:
if np.isnan(nwp.data).any():
raise ValueError("NWP DataArray contains NaNs")
elif isinstance(nwp, xr.Dataset):
for var in nwp.data_vars:
if dask.is_dask_collection(nwp[var].data):
if dask.array.isnan(nwp[var].data).any().compute():
raise ValueError(f"NWP Dataset variable{var} contains NaNs")
else:
if np.isnan(nwp[var].data).any():
raise ValueError(f"NWP Dataset variable{var} contains NaNs")
43 changes: 0 additions & 43 deletions tests/batch/test_merge_numpy_examples_to_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,49 +40,13 @@ def _single_batch_sample(fill_value):
return sample


def _single_batch_sample_nan(fill_value):
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed all the functions related to checking nans

"""This function allows us to create batches with different filled values"""

sample: NumpyBatch = {}
sample[BatchKey.satellite_actual] = np.full(
(12, 10, 24, 24), fill_value, dtype=np.float32
) # shape: (time, channel, x, y)
sample[BatchKey.gsp_id] = np.full((1,), fill_value) # shape: (1,)
sample[BatchKey.gsp_t0_idx] = 4 # scalar and constant across all samples

sample_nwp_ukv: NWPNumpyBatch = {}
sample_nwp_ukv[NWPBatchKey.nwp] = np.full(
(8, 2, 24, 24), fill_value, dtype=np.float32
) # shape: (time, variable, x, y)
sample_nwp_ukv[NWPBatchKey.nwp][0, 0, 0, 0] = np.nan

sample_nwp_ukv[NWPBatchKey.nwp_channel_names] = ["a", "b"] # shape: (variable,)

sample_nwp_ecmwf: NWPNumpyBatch = {}
sample_nwp_ecmwf[NWPBatchKey.nwp] = np.full(
(8, 4, 12, 12), fill_value
) # shape: (time, variable, x, y)

sample[BatchKey.nwp] = {
"ukv": sample_nwp_ukv,
"ecmwf": sample_nwp_ecmwf,
}
# print(sample[BatchKey.nwp]["ukv"])
return sample


@pytest.fixture
def numpy_sample_datapipe():
dp = IterableWrapper([_single_batch_sample(i) for i in range(8)])
return dp


@pytest.fixture
def numpy_nan_sample_datapipe():
dp = IterableWrapper([_single_batch_sample_nan(i) for i in range(8)])
return dp


def test_merge_numpy_batch(numpy_sample_datapipe):
dp = MergeNumpyBatchIterDataPipe(numpy_sample_datapipe.batch(4))
dp_iter = iter(dp)
Expand All @@ -99,13 +63,6 @@ def test_merge_numpy_batch(numpy_sample_datapipe):
assert nwp_batch[NWPBatchKey.nwp_channel_names] == ["a", "b"]


def test_merge_numpy_batch_for_nans(numpy_nan_sample_datapipe):
with pytest.raises(
ValueError
): # checks for Error raised if NWP/BatchKey DataArray contains Nans
dp = MergeNumpyBatchIterDataPipe(numpy_nan_sample_datapipe.batch(4))
dp_iter = iter(dp)
metadata = next(dp_iter)


def test_merge_numpy_examples_to_batch(numpy_sample_datapipe):
Expand Down
25 changes: 20 additions & 5 deletions tests/load/nwp/test_load_nwp.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def test_check_for_zeros():
# to generate data with zeros and limits:
original_store_path = "tests/data/nwp_data/test.zarr"
original_store = zarr.open(original_store_path, mode="r")
new_store_path = "tests/data/nwp_data/test_with_zeros_n_limits.zarr"
new_store_path = "tests/data/nwp_data/test_with_zeros_n_limits_n_nans.zarr"
# Optionally, clear the destination store if it already exists
shutil.rmtree(new_store_path, ignore_errors=True)
with zarr.open(new_store_path, mode="w") as new_store:
Expand All @@ -120,9 +120,11 @@ def test_check_for_zeros():

new_store["UKV"][0, 0, 0, 0] = 0
new_store["UKV"][0, 0, 0, 1] = np.random.uniform(190, 360, size=(548,))
new_store["UKV"][0, 0, 0, 2] = np.nan

shutil.copy(
"tests/data/nwp_data/test.zarr/.zmetadata",
"tests/data/nwp_data/test_with_zeros_n_limits.zarr/.zmetadata",
"tests/data/nwp_data/test_with_zeros_n_limits_n_nans.zarr/.zmetadata",
)

# positive test case
Expand All @@ -142,7 +144,7 @@ def test_check_for_zeros():
def test_check_physical_limits():
# positive test case
nwp_datapipe1 = OpenNWP(
zarr_path="tests/data/nwp_data/test_with_zeros_n_limits.zarr", check_physical_limits=True
zarr_path="tests/data/nwp_data/test_with_zeros_n_limits_n_nans.zarr", check_physical_limits=True
)
with pytest.raises(
ValueError
Expand All @@ -154,6 +156,19 @@ def test_check_physical_limits():
metadata = next(iter(nwp_datapipe2))
assert metadata is not None

def test_check_if_nans():
# positive test case
nwp_datapipe1 = OpenNWP(
zarr_path="tests/data/nwp_data/test_with_zeros_n_limits_n_nans.zarr", check_for_nans=True
)
with pytest.raises(ValueError): # checks for Error raised if NWP DataArray contains nans
metadata = next(iter(nwp_datapipe1))

# negative test case
nwp_datapipe2 = OpenNWP(zarr_path="tests/data/nwp_data/test.zarr", check_for_nans=True)
metadata = next(iter(nwp_datapipe2))
assert metadata is not None

shutil.rmtree(
"tests/data/nwp_data/test_with_zeros_n_limits.zarr"
) # removes the zarr file created for testing
"tests/data/nwp_data/test_with_zeros_n_limits_n_nans.zarr"
) # removes the zarr file created for testing