From 499747ef0b11e016675ce59f5a741b63c448b585 Mon Sep 17 00:00:00 2001 From: arnavsharma990 <2006arnavsharma@gmail.com> Date: Fri, 20 Feb 2026 13:40:28 +0530 Subject: [PATCH 1/2] tests: add unit tests for functions without direct test coverage Add focused unit tests for: - check_point_in_dataset: test point exists/not exists/None cases - check_step: test constant step matching/mismatching, non-constant step, edge cases - load_input_dataset: test zarr/netCDF loading and error handling - check_chunk_size: test warning behavior for small/large chunks, missing dimensions - chunk_dataset: test successful chunking and error handling These tests follow existing test patterns and provide minimal but complete coverage for previously untested helper functions. Co-authored-by: Cursor --- tests/test_chunking.py | 76 +++++++++++++++++++++++++++++ tests/test_loading.py | 47 ++++++++++++++++++ tests/test_selection_helpers.py | 84 +++++++++++++++++++++++++++++++++ 3 files changed, 207 insertions(+) create mode 100644 tests/test_chunking.py create mode 100644 tests/test_loading.py create mode 100644 tests/test_selection_helpers.py diff --git a/tests/test_chunking.py b/tests/test_chunking.py new file mode 100644 index 0000000..990cf5a --- /dev/null +++ b/tests/test_chunking.py @@ -0,0 +1,76 @@ +""" +Unit tests for ops.chunking module. +""" +import numpy as np +import pytest +import xarray as xr + +from mllam_data_prep.ops.chunking import check_chunk_size, chunk_dataset + + +@pytest.fixture +def small_dataset(): + """Create a small test dataset.""" + return xr.Dataset( + { + "var1": (["x", "y"], np.random.random((10, 10))), + "var2": (["x", "y"], np.random.random((10, 10))), + }, + coords={"x": range(10), "y": range(10)}, + ) + + +@pytest.fixture +def large_dataset(): + """Create a dataset that will exceed chunk size warning.""" + # Create dataset with large chunks that exceed 1GB warning + # Using float64 (8 bytes), need > 1GB / 8 = 134217728 elements + # For simplicity, create a smaller but still large dataset + size = 5000 + return xr.Dataset( + { + "large_var": (["x", "y"], np.random.random((size, size))), + }, + coords={"x": range(size), "y": range(size)}, + ) + + +def test_check_chunk_size_small_chunks(small_dataset, caplog): + """Test check_chunk_size with small chunks (should not warn).""" + chunks = {"x": 5, "y": 5} + check_chunk_size(small_dataset, chunks) + # Should not log any warnings + assert len(caplog.records) == 0 + + +def test_check_chunk_size_large_chunks(large_dataset, caplog): + """Test check_chunk_size with large chunks (should warn).""" + # Use chunks that will create large memory usage + chunks = {"x": 1000, "y": 1000} + check_chunk_size(large_dataset, chunks) + # Should log a warning + assert len(caplog.records) > 0 + assert "exceeds" in caplog.records[0].message.lower() + + +def test_check_chunk_size_missing_dimension(small_dataset): + """Test check_chunk_size when dimension doesn't exist in variable.""" + chunks = {"x": 5, "z": 10} # z doesn't exist + # Should not raise, just skip the missing dimension + check_chunk_size(small_dataset, chunks) + + +def test_chunk_dataset_success(small_dataset): + """Test chunk_dataset successfully chunks a dataset.""" + chunks = {"x": 5, "y": 5} + chunked = chunk_dataset(small_dataset, chunks) + assert isinstance(chunked, xr.Dataset) + # Check that chunking was applied + assert chunked["var1"].chunks is not None + + +def test_chunk_dataset_invalid_chunks(small_dataset): + """Test chunk_dataset with invalid chunk specification.""" + chunks = {"x": -1} # Invalid chunk size + with pytest.raises(Exception, match="Error chunking dataset"): + chunk_dataset(small_dataset, chunks) diff --git a/tests/test_loading.py b/tests/test_loading.py new file mode 100644 index 0000000..90e57c8 --- /dev/null +++ b/tests/test_loading.py @@ -0,0 +1,47 @@ +""" +Unit tests for ops.loading module. +""" +import tempfile +from pathlib import Path + +import pytest +import xarray as xr + +from mllam_data_prep.ops.loading import load_input_dataset + + +@pytest.fixture +def sample_dataset(): + """Create a simple test dataset.""" + return xr.Dataset( + {"var": (["x"], [1, 2, 3])}, + coords={"x": [0, 1, 2]}, + ) + + +def test_load_input_dataset_zarr(sample_dataset, tmp_path): + """Test load_input_dataset with zarr format.""" + zarr_path = tmp_path / "test.zarr" + sample_dataset.to_zarr(zarr_path, mode="w") + + loaded = load_input_dataset(str(zarr_path)) + assert isinstance(loaded, xr.Dataset) + assert "var" in loaded.data_vars + assert list(loaded.x.values) == [0, 1, 2] + + +def test_load_input_dataset_netcdf(sample_dataset, tmp_path): + """Test load_input_dataset with netCDF format.""" + nc_path = tmp_path / "test.nc" + sample_dataset.to_netcdf(nc_path) + + loaded = load_input_dataset(str(nc_path)) + assert isinstance(loaded, xr.Dataset) + assert "var" in loaded.data_vars + assert list(loaded.x.values) == [0, 1, 2] + + +def test_load_input_dataset_nonexistent(): + """Test load_input_dataset with non-existent file.""" + with pytest.raises((OSError, FileNotFoundError)): + load_input_dataset("/nonexistent/path/to/file.zarr") diff --git a/tests/test_selection_helpers.py b/tests/test_selection_helpers.py new file mode 100644 index 0000000..ac48587 --- /dev/null +++ b/tests/test_selection_helpers.py @@ -0,0 +1,84 @@ +""" +Unit tests for helper functions in ops.selection module. +""" +import datetime + +import pandas as pd +import pytest +import xarray as xr + +from mllam_data_prep.ops.selection import check_point_in_dataset, check_step + + +@pytest.fixture +def simple_time_dataset(): + """Create a simple dataset with time coordinate.""" + time_values = pd.date_range("2020-01-01", periods=5, freq="3H") + return xr.Dataset( + {"var": (["time"], range(5))}, + coords={"time": time_values}, + ) + + +def test_check_point_in_dataset_point_exists(simple_time_dataset): + """Test check_point_in_dataset when point exists in coordinate.""" + point = simple_time_dataset.time.values[2] + # Should not raise + check_point_in_dataset("time", point, simple_time_dataset) + + +def test_check_point_in_dataset_point_not_exists(simple_time_dataset): + """Test check_point_in_dataset when point does not exist in coordinate.""" + point = pd.Timestamp("2020-01-02T12:00") + with pytest.raises(ValueError, match="Provided value for coordinate time"): + check_point_in_dataset("time", point, simple_time_dataset) + + +def test_check_point_in_dataset_none_point(simple_time_dataset): + """Test check_point_in_dataset when point is None (should not raise).""" + # Should not raise when point is None + check_point_in_dataset("time", None, simple_time_dataset) + + +def test_check_step_constant_step_matches(simple_time_dataset): + """Test check_step when step is constant and matches requested step.""" + requested_step = pd.Timedelta(hours=3) + # Should not raise + check_step(requested_step, "time", simple_time_dataset) + + +def test_check_step_constant_step_mismatch(simple_time_dataset): + """Test check_step when step is constant but doesn't match requested step.""" + requested_step = pd.Timedelta(hours=6) + with pytest.raises(ValueError, match="Step size for coordinate time"): + check_step(requested_step, "time", simple_time_dataset) + + +def test_check_step_non_constant_step(): + """Test check_step when step size is not constant.""" + # Create dataset with non-constant time steps + time_values = pd.to_datetime( + ["2020-01-01T00:00", "2020-01-01T03:00", "2020-01-01T10:00", "2020-01-01T13:00"] + ) + ds = xr.Dataset( + {"var": (["time"], range(4))}, + coords={"time": time_values}, + ) + requested_step = pd.Timedelta(hours=3) + with pytest.raises(ValueError, match="Step size for coordinate time is not constant"): + check_step(requested_step, "time", ds) + + +def test_check_step_single_point_coordinate(): + """Test check_step with single point coordinate (edge case - will raise IndexError).""" + # Create dataset with single time point (diff will be empty array) + time_values = pd.date_range("2020-01-01", periods=1, freq="3H") + ds = xr.Dataset( + {"var": (["time"], [1])}, + coords={"time": time_values}, + ) + requested_step = pd.Timedelta(hours=3) + # This will raise IndexError when trying to access all_steps[0] on empty array + # This documents current behavior - could be improved to raise more descriptive error + with pytest.raises(IndexError): + check_step(requested_step, "time", ds) From f5cd1ab356c8fca5ea650454d4d8758eece4be0a Mon Sep 17 00:00:00 2001 From: arnavsharma990 <2006arnavsharma@gmail.com> Date: Fri, 20 Feb 2026 13:48:20 +0530 Subject: [PATCH 2/2] fix: address Copilot suggestions for test improvements - test_chunking.py: - Remove large_dataset fixture (inefficient 5000x5000 allocation) - Fix test_check_chunk_size_large_chunks to use chunk sizes that actually exceed 1GB threshold (12000x12000) instead of 1000x1000 - Use loguru handler to capture logs instead of caplog (which doesn't capture loguru output) - Remove caplog from test_check_chunk_size_small_chunks - test_loading.py: - Remove unused imports: tempfile and Path - Fix test_load_input_dataset_netcdf to use pytest.importorskip for netCDF4 engine and specify engine explicitly - test_selection_helpers.py: - Remove unused datetime import - Update test_check_step_single_point_coordinate to expect ValueError instead of IndexError - ops/selection.py: - Fix check_step to raise descriptive ValueError when coordinate has fewer than 2 points, instead of allowing IndexError Co-authored-by: Cursor --- mllam_data_prep/ops/selection.py | 4 +++ tests/test_chunking.py | 46 +++++++++++++++----------------- tests/test_loading.py | 8 +++--- tests/test_selection_helpers.py | 10 +++---- 4 files changed, 32 insertions(+), 36 deletions(-) diff --git a/mllam_data_prep/ops/selection.py b/mllam_data_prep/ops/selection.py index 37b91c1..fd3cb85 100644 --- a/mllam_data_prep/ops/selection.py +++ b/mllam_data_prep/ops/selection.py @@ -106,6 +106,10 @@ def check_step(sel_step, coord, ds): """ check that the step requested is exactly what the data has """ + if len(ds[coord]) < 2: + raise ValueError( + f"Cannot compute step size for coordinate {coord} with fewer than 2 points" + ) all_steps = ds[coord].diff(dim=coord).values first_step = all_steps[0].astype("timedelta64[s]").astype(datetime.timedelta) diff --git a/tests/test_chunking.py b/tests/test_chunking.py index 990cf5a..392456b 100644 --- a/tests/test_chunking.py +++ b/tests/test_chunking.py @@ -4,6 +4,7 @@ import numpy as np import pytest import xarray as xr +from loguru import logger from mllam_data_prep.ops.chunking import check_chunk_size, chunk_dataset @@ -20,37 +21,32 @@ def small_dataset(): ) -@pytest.fixture -def large_dataset(): - """Create a dataset that will exceed chunk size warning.""" - # Create dataset with large chunks that exceed 1GB warning - # Using float64 (8 bytes), need > 1GB / 8 = 134217728 elements - # For simplicity, create a smaller but still large dataset - size = 5000 - return xr.Dataset( - { - "large_var": (["x", "y"], np.random.random((size, size))), - }, - coords={"x": range(size), "y": range(size)}, - ) - - -def test_check_chunk_size_small_chunks(small_dataset, caplog): +def test_check_chunk_size_small_chunks(small_dataset): """Test check_chunk_size with small chunks (should not warn).""" chunks = {"x": 5, "y": 5} + # Should not raise or warn check_chunk_size(small_dataset, chunks) - # Should not log any warnings - assert len(caplog.records) == 0 -def test_check_chunk_size_large_chunks(large_dataset, caplog): +def test_check_chunk_size_large_chunks(small_dataset): """Test check_chunk_size with large chunks (should warn).""" - # Use chunks that will create large memory usage - chunks = {"x": 1000, "y": 1000} - check_chunk_size(large_dataset, chunks) - # Should log a warning - assert len(caplog.records) > 0 - assert "exceeds" in caplog.records[0].message.lower() + # Use chunk sizes that exceed 1GB threshold + # For float64 (8 bytes), need chunks product > 1GB / 8 = 134217728 + # Using chunks of 12000 x 12000 = 144000000 elements > 134217728 + chunks = {"x": 12000, "y": 12000} + + # Capture loguru logs using a handler + from io import StringIO + + log_capture = StringIO() + handler_id = logger.add(log_capture, format="{message}") + + try: + check_chunk_size(small_dataset, chunks) + log_output = log_capture.getvalue() + assert "exceeds" in log_output.lower() + finally: + logger.remove(handler_id) def test_check_chunk_size_missing_dimension(small_dataset): diff --git a/tests/test_loading.py b/tests/test_loading.py index 90e57c8..e751c1b 100644 --- a/tests/test_loading.py +++ b/tests/test_loading.py @@ -1,9 +1,6 @@ """ Unit tests for ops.loading module. """ -import tempfile -from pathlib import Path - import pytest import xarray as xr @@ -32,8 +29,11 @@ def test_load_input_dataset_zarr(sample_dataset, tmp_path): def test_load_input_dataset_netcdf(sample_dataset, tmp_path): """Test load_input_dataset with netCDF format.""" + # Skip if NetCDF engine is not available + pytest.importorskip("netCDF4") + nc_path = tmp_path / "test.nc" - sample_dataset.to_netcdf(nc_path) + sample_dataset.to_netcdf(nc_path, engine="netcdf4") loaded = load_input_dataset(str(nc_path)) assert isinstance(loaded, xr.Dataset) diff --git a/tests/test_selection_helpers.py b/tests/test_selection_helpers.py index ac48587..bf93596 100644 --- a/tests/test_selection_helpers.py +++ b/tests/test_selection_helpers.py @@ -1,8 +1,6 @@ """ Unit tests for helper functions in ops.selection module. """ -import datetime - import pandas as pd import pytest import xarray as xr @@ -70,15 +68,13 @@ def test_check_step_non_constant_step(): def test_check_step_single_point_coordinate(): - """Test check_step with single point coordinate (edge case - will raise IndexError).""" - # Create dataset with single time point (diff will be empty array) + """Test check_step with single point coordinate (should raise descriptive ValueError).""" + # Create dataset with single time point time_values = pd.date_range("2020-01-01", periods=1, freq="3H") ds = xr.Dataset( {"var": (["time"], [1])}, coords={"time": time_values}, ) requested_step = pd.Timedelta(hours=3) - # This will raise IndexError when trying to access all_steps[0] on empty array - # This documents current behavior - could be improved to raise more descriptive error - with pytest.raises(IndexError): + with pytest.raises(ValueError, match="Cannot compute step size.*fewer than 2 points"): check_step(requested_step, "time", ds)