Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mllam_data_prep/ops/selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ def check_step(sel_step, coord, ds):
"""
check that the step requested is exactly what the data has
"""
if len(ds[coord]) < 2:
raise ValueError(
f"Cannot compute step size for coordinate {coord} with fewer than 2 points"
)
all_steps = ds[coord].diff(dim=coord).values
first_step = all_steps[0].astype("timedelta64[s]").astype(datetime.timedelta)

Expand Down
72 changes: 72 additions & 0 deletions tests/test_chunking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""
Unit tests for ops.chunking module.
"""
import numpy as np
import pytest
import xarray as xr
from loguru import logger

from mllam_data_prep.ops.chunking import check_chunk_size, chunk_dataset


@pytest.fixture
def small_dataset():
    """Build a 10x10 test dataset with two random-valued variables."""
    dims = ["x", "y"]
    shape = (10, 10)
    # Two identically-shaped variables so chunk checks see more than one array.
    data_vars = {name: (dims, np.random.random(shape)) for name in ("var1", "var2")}
    coords = {dim: range(10) for dim in dims}
    return xr.Dataset(data_vars, coords=coords)


def test_check_chunk_size_small_chunks(small_dataset):
    """Test check_chunk_size with small chunks (should not warn)."""
    # A 5x5 float64 chunk is far below any warning threshold, so the call
    # must complete without raising.
    check_chunk_size(small_dataset, {"x": 5, "y": 5})


def test_check_chunk_size_large_chunks(small_dataset):
    """Test check_chunk_size with large chunks (should warn)."""
    from io import StringIO

    # 12000 * 12000 = 144000000 float64 elements, above the 1 GB threshold
    # of 1 GB / 8 bytes = 134217728 elements.
    oversized_chunks = {"x": 12000, "y": 12000}

    # Route loguru output into an in-memory sink so the warning text can be
    # inspected after the call.
    sink = StringIO()
    sink_id = logger.add(sink, format="{message}")

    try:
        check_chunk_size(small_dataset, oversized_chunks)
        assert "exceeds" in sink.getvalue().lower()
    finally:
        # Always detach the sink so later tests don't inherit it.
        logger.remove(sink_id)


def test_check_chunk_size_missing_dimension(small_dataset):
    """Test check_chunk_size when dimension doesn't exist in variable."""
    # "z" is not a dimension of the dataset; the check is expected to skip
    # it silently instead of raising.
    check_chunk_size(small_dataset, {"x": 5, "z": 10})


def test_chunk_dataset_success(small_dataset):
    """Test chunk_dataset successfully chunks a dataset."""
    result = chunk_dataset(small_dataset, {"x": 5, "y": 5})
    # The return value must still be a dataset whose variables now carry
    # chunk metadata.
    assert isinstance(result, xr.Dataset)
    assert result["var1"].chunks is not None


def test_chunk_dataset_invalid_chunks(small_dataset):
    """Test chunk_dataset with invalid chunk specification."""
    bad_chunks = {"x": -1}  # negative sizes are not a valid chunk spec
    # chunk_dataset is expected to wrap the underlying failure in its own
    # "Error chunking dataset" message.
    with pytest.raises(Exception, match="Error chunking dataset"):
        chunk_dataset(small_dataset, bad_chunks)
47 changes: 47 additions & 0 deletions tests/test_loading.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""
Unit tests for ops.loading module.
"""
import pytest
import xarray as xr

from mllam_data_prep.ops.loading import load_input_dataset


@pytest.fixture
def sample_dataset():
    """Build a minimal one-dimensional dataset for round-trip tests."""
    data = {"var": (["x"], [1, 2, 3])}
    coords = {"x": [0, 1, 2]}
    return xr.Dataset(data, coords=coords)


def test_load_input_dataset_zarr(sample_dataset, tmp_path):
    """Test load_input_dataset with zarr format.

    Writes the fixture dataset to a temporary zarr store and checks that
    load_input_dataset reads it back intact.
    """
    # Skip when the optional zarr engine is unavailable, mirroring the
    # netCDF test's importorskip guard for its engine.
    pytest.importorskip("zarr")

    zarr_path = tmp_path / "test.zarr"
    sample_dataset.to_zarr(zarr_path, mode="w")

    loaded = load_input_dataset(str(zarr_path))
    assert isinstance(loaded, xr.Dataset)
    assert "var" in loaded.data_vars
    assert list(loaded.x.values) == [0, 1, 2]


def test_load_input_dataset_netcdf(sample_dataset, tmp_path):
    """Test load_input_dataset with netCDF format."""
    # The netCDF4 engine is an optional dependency; skip when missing.
    pytest.importorskip("netCDF4")

    nc_path = tmp_path / "test.nc"
    sample_dataset.to_netcdf(nc_path, engine="netcdf4")

    reloaded = load_input_dataset(str(nc_path))
    assert isinstance(reloaded, xr.Dataset)
    assert "var" in reloaded.data_vars
    assert list(reloaded.x.values) == [0, 1, 2]


def test_load_input_dataset_nonexistent():
    """Test load_input_dataset with non-existent file."""
    # FileNotFoundError is a subclass of OSError, so matching OSError alone
    # accepts exactly the same set of exceptions as the redundant tuple
    # (OSError, FileNotFoundError) did.
    with pytest.raises(OSError):
        load_input_dataset("/nonexistent/path/to/file.zarr")
80 changes: 80 additions & 0 deletions tests/test_selection_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""
Unit tests for helper functions in ops.selection module.
"""
import pandas as pd
import pytest
import xarray as xr

from mllam_data_prep.ops.selection import check_point_in_dataset, check_step


@pytest.fixture
def simple_time_dataset():
    """Create a dataset with five time points spaced 3 hours apart."""
    # Lowercase "3h" avoids the FutureWarning that pandas >= 2.2 emits for
    # the deprecated uppercase "H" offset alias; the frequency is identical.
    time_values = pd.date_range("2020-01-01", periods=5, freq="3h")
    return xr.Dataset(
        {"var": (["time"], range(5))},
        coords={"time": time_values},
    )


def test_check_point_in_dataset_point_exists(simple_time_dataset):
    """Test check_point_in_dataset when point exists in coordinate."""
    # A value taken directly from the coordinate must be accepted silently.
    existing_value = simple_time_dataset.time.values[2]
    check_point_in_dataset("time", existing_value, simple_time_dataset)


def test_check_point_in_dataset_point_not_exists(simple_time_dataset):
    """Test check_point_in_dataset when point does not exist in coordinate."""
    # This timestamp lies outside the fixture's time axis entirely.
    missing_point = pd.Timestamp("2020-01-02T12:00")
    with pytest.raises(ValueError, match="Provided value for coordinate time"):
        check_point_in_dataset("time", missing_point, simple_time_dataset)


def test_check_point_in_dataset_none_point(simple_time_dataset):
    """Test check_point_in_dataset when point is None (should not raise)."""
    # A None point is expected to be treated as "nothing to validate".
    check_point_in_dataset("time", None, simple_time_dataset)


def test_check_step_constant_step_matches(simple_time_dataset):
    """Test check_step when step is constant and matches requested step."""
    # The fixture's spacing is exactly 3 hours, so this must pass silently.
    check_step(pd.Timedelta(hours=3), "time", simple_time_dataset)


def test_check_step_constant_step_mismatch(simple_time_dataset):
    """Test check_step when step is constant but doesn't match requested step."""
    wrong_step = pd.Timedelta(hours=6)  # fixture data is spaced 3-hourly
    with pytest.raises(ValueError, match="Step size for coordinate time"):
        check_step(wrong_step, "time", simple_time_dataset)


def test_check_step_non_constant_step():
    """Test check_step when step size is not constant."""
    # Gaps of 3h, 7h and 3h make the spacing irregular on purpose.
    timestamps = pd.to_datetime(
        ["2020-01-01T00:00", "2020-01-01T03:00", "2020-01-01T10:00", "2020-01-01T13:00"]
    )
    irregular_ds = xr.Dataset(
        {"var": (["time"], range(4))},
        coords={"time": timestamps},
    )
    with pytest.raises(ValueError, match="Step size for coordinate time is not constant"):
        check_step(pd.Timedelta(hours=3), "time", irregular_ds)


def test_check_step_single_point_coordinate():
    """Test check_step with single point coordinate (should raise descriptive ValueError)."""
    # Lowercase "3h" avoids the FutureWarning that pandas >= 2.2 emits for
    # the deprecated uppercase "H" offset alias.
    time_values = pd.date_range("2020-01-01", periods=1, freq="3h")
    ds = xr.Dataset(
        {"var": (["time"], [1])},
        coords={"time": time_values},
    )
    requested_step = pd.Timedelta(hours=3)
    # With a single point no difference can be computed; check_step must fail
    # with its dedicated error rather than an IndexError on the diff array.
    with pytest.raises(ValueError, match="Cannot compute step size.*fewer than 2 points"):
        check_step(requested_step, "time", ds)