diff --git a/mllam_data_prep/ops/selection.py b/mllam_data_prep/ops/selection.py index 37b91c1..fd3cb85 100644 --- a/mllam_data_prep/ops/selection.py +++ b/mllam_data_prep/ops/selection.py @@ -106,6 +106,10 @@ def check_step(sel_step, coord, ds): """ check that the step requested is exactly what the data has """ + if len(ds[coord]) < 2: + raise ValueError( + f"Cannot compute step size for coordinate {coord} with fewer than 2 points" + ) all_steps = ds[coord].diff(dim=coord).values first_step = all_steps[0].astype("timedelta64[s]").astype(datetime.timedelta) diff --git a/tests/test_chunking.py b/tests/test_chunking.py new file mode 100644 index 0000000..392456b --- /dev/null +++ b/tests/test_chunking.py @@ -0,0 +1,72 @@ +""" +Unit tests for ops.chunking module. +""" +import numpy as np +import pytest +import xarray as xr +from loguru import logger + +from mllam_data_prep.ops.chunking import check_chunk_size, chunk_dataset + + +@pytest.fixture +def small_dataset(): + """Create a small test dataset.""" + return xr.Dataset( + { + "var1": (["x", "y"], np.random.random((10, 10))), + "var2": (["x", "y"], np.random.random((10, 10))), + }, + coords={"x": range(10), "y": range(10)}, + ) + + +def test_check_chunk_size_small_chunks(small_dataset): + """Test check_chunk_size with small chunks (should not warn).""" + chunks = {"x": 5, "y": 5} + # Should not raise or warn + check_chunk_size(small_dataset, chunks) + + +def test_check_chunk_size_large_chunks(small_dataset): + """Test check_chunk_size with large chunks (should warn).""" + # Use chunk sizes that exceed 1GB threshold + # For float64 (8 bytes), need chunks product > 1GB / 8 = 134217728 + # Using chunks of 12000 x 12000 = 144000000 elements > 134217728 + chunks = {"x": 12000, "y": 12000} + + # Capture loguru logs using a handler + from io import StringIO + + log_capture = StringIO() + handler_id = logger.add(log_capture, format="{message}") + + try: + check_chunk_size(small_dataset, chunks) + log_output = 
log_capture.getvalue()
+        assert "exceeds" in log_output.lower()
+    finally:
+        logger.remove(handler_id)
+
+
+def test_check_chunk_size_missing_dimension(small_dataset):
+    """Test check_chunk_size when dimension doesn't exist in variable."""
+    chunks = {"x": 5, "z": 10}  # z doesn't exist
+    # Should not raise, just skip the missing dimension
+    check_chunk_size(small_dataset, chunks)
+
+
+def test_chunk_dataset_success(small_dataset):
+    """Test chunk_dataset successfully chunks a dataset."""
+    chunks = {"x": 5, "y": 5}
+    chunked = chunk_dataset(small_dataset, chunks)
+    assert isinstance(chunked, xr.Dataset)
+    # Check that chunking was applied
+    assert chunked["var1"].chunks is not None
+
+
+def test_chunk_dataset_invalid_chunks(small_dataset):
+    """Test chunk_dataset with invalid chunk specification."""
+    # NOTE: -1 is a *valid* dask chunk spec ("entire dimension as one chunk"),
+    # so it would not raise; use a byte-size string dask cannot parse instead.
+    chunks = {"x": "not-a-valid-size"}
+    with pytest.raises(Exception, match="Error chunking dataset"):
+        chunk_dataset(small_dataset, chunks)
diff --git a/tests/test_loading.py b/tests/test_loading.py
new file mode 100644
index 0000000..e751c1b
--- /dev/null
+++ b/tests/test_loading.py
@@ -0,0 +1,47 @@
+"""
+Unit tests for ops.loading module.
+""" +import pytest +import xarray as xr + +from mllam_data_prep.ops.loading import load_input_dataset + + +@pytest.fixture +def sample_dataset(): + """Create a simple test dataset.""" + return xr.Dataset( + {"var": (["x"], [1, 2, 3])}, + coords={"x": [0, 1, 2]}, + ) + + +def test_load_input_dataset_zarr(sample_dataset, tmp_path): + """Test load_input_dataset with zarr format.""" + zarr_path = tmp_path / "test.zarr" + sample_dataset.to_zarr(zarr_path, mode="w") + + loaded = load_input_dataset(str(zarr_path)) + assert isinstance(loaded, xr.Dataset) + assert "var" in loaded.data_vars + assert list(loaded.x.values) == [0, 1, 2] + + +def test_load_input_dataset_netcdf(sample_dataset, tmp_path): + """Test load_input_dataset with netCDF format.""" + # Skip if NetCDF engine is not available + pytest.importorskip("netCDF4") + + nc_path = tmp_path / "test.nc" + sample_dataset.to_netcdf(nc_path, engine="netcdf4") + + loaded = load_input_dataset(str(nc_path)) + assert isinstance(loaded, xr.Dataset) + assert "var" in loaded.data_vars + assert list(loaded.x.values) == [0, 1, 2] + + +def test_load_input_dataset_nonexistent(): + """Test load_input_dataset with non-existent file.""" + with pytest.raises((OSError, FileNotFoundError)): + load_input_dataset("/nonexistent/path/to/file.zarr") diff --git a/tests/test_selection_helpers.py b/tests/test_selection_helpers.py new file mode 100644 index 0000000..bf93596 --- /dev/null +++ b/tests/test_selection_helpers.py @@ -0,0 +1,80 @@ +""" +Unit tests for helper functions in ops.selection module. 
+""" +import pandas as pd +import pytest +import xarray as xr + +from mllam_data_prep.ops.selection import check_point_in_dataset, check_step + + +@pytest.fixture +def simple_time_dataset(): + """Create a simple dataset with time coordinate.""" + time_values = pd.date_range("2020-01-01", periods=5, freq="3H") + return xr.Dataset( + {"var": (["time"], range(5))}, + coords={"time": time_values}, + ) + + +def test_check_point_in_dataset_point_exists(simple_time_dataset): + """Test check_point_in_dataset when point exists in coordinate.""" + point = simple_time_dataset.time.values[2] + # Should not raise + check_point_in_dataset("time", point, simple_time_dataset) + + +def test_check_point_in_dataset_point_not_exists(simple_time_dataset): + """Test check_point_in_dataset when point does not exist in coordinate.""" + point = pd.Timestamp("2020-01-02T12:00") + with pytest.raises(ValueError, match="Provided value for coordinate time"): + check_point_in_dataset("time", point, simple_time_dataset) + + +def test_check_point_in_dataset_none_point(simple_time_dataset): + """Test check_point_in_dataset when point is None (should not raise).""" + # Should not raise when point is None + check_point_in_dataset("time", None, simple_time_dataset) + + +def test_check_step_constant_step_matches(simple_time_dataset): + """Test check_step when step is constant and matches requested step.""" + requested_step = pd.Timedelta(hours=3) + # Should not raise + check_step(requested_step, "time", simple_time_dataset) + + +def test_check_step_constant_step_mismatch(simple_time_dataset): + """Test check_step when step is constant but doesn't match requested step.""" + requested_step = pd.Timedelta(hours=6) + with pytest.raises(ValueError, match="Step size for coordinate time"): + check_step(requested_step, "time", simple_time_dataset) + + +def test_check_step_non_constant_step(): + """Test check_step when step size is not constant.""" + # Create dataset with non-constant time steps + time_values = 
pd.to_datetime(
+        ["2020-01-01T00:00", "2020-01-01T03:00", "2020-01-01T10:00", "2020-01-01T13:00"]
+    )
+    ds = xr.Dataset(
+        {"var": (["time"], range(4))},
+        coords={"time": time_values},
+    )
+    requested_step = pd.Timedelta(hours=3)
+    with pytest.raises(ValueError, match="Step size for coordinate time is not constant"):
+        check_step(requested_step, "time", ds)
+
+
+def test_check_step_single_point_coordinate():
+    """Test check_step with single point coordinate (should raise descriptive ValueError)."""
+    # Create dataset with single time point ("3h": uppercase "3H" is deprecated in pandas 2.2+)
+    time_values = pd.date_range("2020-01-01", periods=1, freq="3h")
+    ds = xr.Dataset(
+        {"var": (["time"], [1])},
+        coords={"time": time_values},
+    )
+    requested_step = pd.Timedelta(hours=3)
+    with pytest.raises(ValueError, match="Cannot compute step size.*fewer than 2 points"):
+        check_step(requested_step, "time", ds)